-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(update) Simplified macros and mixed precision support. Also improved docs with more links to connect each old routine with the new ones
- Loading branch information
1 parent
a8e91ff
commit 76daa56
Showing
117 changed files
with
1,291 additions
and
1,275 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,177 +1,257 @@ | ||
#:mute | ||
#:set REAL_TYPE = 'real(wp)' | ||
#:set COMPLEX_TYPE = 'complex(wp)' | ||
|
||
#:set REAL = 'real(wp)' | ||
#:set COMPLEX = 'complex(wp)' | ||
#:set PREFIX = { & | ||
's': { 'type': 'real(wp)', 'wp': 'REAL32'}, & | ||
'd': { 'type': 'real(wp)', 'wp': 'REAL64'}, & | ||
'c': { 'type': 'complex(wp)', 'wp': 'REAL32'}, & | ||
'z': { 'type': 'complex(wp)', 'wp': 'REAL64'}, & | ||
} | ||
|
||
#:set ERROR = lambda pfx: { 'type': f'error: {pfx}', 'wp' : f'error: {pfx}' } | ||
|
||
#! Small helpers shared by the macros below.
#!   mix(l, r)         : pairwise concatenation of two prefix lists
#!   split(pfx)        : list of characters for a multi-character prefix,
#!                       the prefix itself when it is a single character
#!   get_types(pfxs)   : (first, second) prefix; the first is repeated when
#!                       only one prefix is given
#!   get(pfx, what)    : field `what` of the PREFIX entry of `pfx`; an unknown
#!                       prefix falls back to ERROR(pfx) instead of failing on None
#!   prefix(pfx, name) : replaces every '?' in `name` by `pfx`
#!   kind(pfx)         : 'wp' field (kind constant) of the prefix
#!   type / real / complex : type strings with 'wp' replaced by kind(pfx)
#:set mix = lambda l, r: list(lp + rp for lp, rp in zip(l,r))
#:set split = lambda pfx: list(pfx) if len(pfx) > 1 else pfx
#:set get_types = lambda pfxs: (pfxs[0], pfxs[0] if len(pfxs) == 1 else pfxs[1])
#:set get = lambda pfx,what: PREFIX.get(pfx, ERROR(pfx)).get(what)
#:set prefix = lambda pfx, name: name.replace('?',pfx)
#:set kind = lambda pfx: get(pfx,'wp')
#:set type = lambda pfx: get(pfx,'type').replace('wp',kind(pfx))
#:set real = lambda pfx: REAL.replace('wp',kind(pfx))
#:set complex= lambda pfx: COMPLEX.replace('wp',kind(pfx))
|
||
#:set SINGLE_TYPES = ['s','c'] | ||
#:set DOUBLE_TYPES = ['d','z'] | ||
#:set REAL_TYPES = ['s','d'] | ||
#:set COMPLEX_TYPES = ['c','z'] | ||
#:set DEFAULT_TYPES = REAL_TYPES + COMPLEX_TYPES | ||
#:set REAL_COMPLEX_TYPES = ['sc','dz'] | ||
#:set COMPLEX_REAL_TYPES = ['cs','zd'] | ||
|
||
#! Function that handles mixed types conventions | ||
#:set MIX = lambda when, use, pfx: & | ||
use[when.index(pfx)] if pfx in when else pfx | ||
#:set MIX_REAL_COMPLEX = mix(REAL_TYPES,COMPLEX_TYPES) | ||
#:set MIX_COMPLEX_REAL = mix(COMPLEX_TYPES,REAL_TYPES) | ||
#:set MIX_SINGLE_DOUBLE = mix(SINGLE_TYPES,DOUBLE_TYPES) | ||
#:set MIX_DOUBLE_SINGLE = mix(DOUBLE_TYPES,SINGLE_TYPES) | ||
|
||
#:set MIX_REAL_COMPLEX = lambda pfx: MIX(COMPLEX_TYPES,REAL_COMPLEX_TYPES,pfx) | ||
#:set MIX_COMPLEX_REAL = lambda pfx: MIX(COMPLEX_TYPES,COMPLEX_REAL_TYPES,pfx) | ||
${type('s')}$ :: variable | ||
${get('s','wp')}$ | ||
${prefix('s','?gemm')}$ | ||
|
||
#:set PREFIX_TO_TYPE={ & | ||
's': REAL_TYPE, & | ||
'd': REAL_TYPE, & | ||
'c': COMPLEX_TYPE, & | ||
'z': COMPLEX_TYPE, & | ||
} | ||
${mix(REAL_TYPES,COMPLEX_TYPES)}$ | ||
${mix(COMPLEX_TYPES,REAL_TYPES)}$ | ||
${mix(SINGLE_TYPES,DOUBLE_TYPES)}$ | ||
${mix(DOUBLE_TYPES,SINGLE_TYPES)}$ | ||
|
||
#:set PREFIX_TO_KIND={& | ||
's': 'REAL32', & | ||
'd': 'REAL64', & | ||
'c': 'REAL32', & | ||
'z': 'REAL64', & | ||
} | ||
${list(map(split, mix(REAL_TYPES, COMPLEX_TYPES)))}$ | ||
${list(map(split, mix(COMPLEX_TYPES, REAL_TYPES)))}$ | ||
${list(map(split, mix(SINGLE_TYPES,DOUBLE_TYPES)))}$ | ||
${list(map(split, mix(DOUBLE_TYPES,SINGLE_TYPES)))}$ | ||
|
||
#:set TYPE_AND_KIND_TO_PREFIX = { & | ||
'real(REAL32)': 's', & | ||
'real(REAL64)': 'd', & | ||
'complex(REAL32)': 'c', & | ||
'complex(REAL64)': 'z', & | ||
} | ||
#! Wraps `code` in a Fortran block that measures its CPU time and prints
#! `message` followed by the elapsed seconds.
#!   message : Fortran character expression (pass it already quoted)
#!   code    : statements executed between the two cpu_time calls
#! The timer locals carry a `timeit_` prefix so they do not shadow variables
#! named t1/t2 that `code` may use from the enclosing scope.
#:def timeit(message, code)
block
    real :: timeit_start, timeit_stop
    call cpu_time(timeit_start)
    $:code
    call cpu_time(timeit_stop)
    print '(A," (",G0,"s)")', ${message}$, timeit_stop-timeit_start
end block
#:enddef
|
||
#! Defines a optional variable, creating local corresponding variable by default | ||
#:def optional(dtype, intent, *args) | ||
#:for variable in args | ||
${dtype}$, intent(${intent}$), optional :: ${variable}$ | ||
${dtype}$ :: local_${variable}$ | ||
#:endfor | ||
#! Fills `name` with pseudo-random values.
#!   type  : Fortran type string of `name`
#!   name  : variable to fill
#!   shape : optional array spec, only forwarded to random_complex
#! Any type string that does not start with 'complex' goes straight to the
#! random_number intrinsic; complex ones are delegated to random_complex.
#:def random_number(type, name, shape='')
#:if not type.startswith('complex')
call random_number(${name}$)
#:else
$:random_complex(type, name,shape)
#:endif
#:enddef
|
||
#! Handles a value of "variable" depending on "condition" | ||
#:def optval(condition, variable, true_value, false_value) | ||
if (${condition}$) then | ||
${variable}$ = ${true_value}$ | ||
else | ||
${variable}$ = ${false_value}$ | ||
end if | ||
#! Fills the complex variable `name` with pseudo-random real and imaginary parts.
#!   type  : complex type string of `name`; the matching real type is derived
#!           from it by replacing 'complex' with 'real'
#!   name  : variable to fill
#!   shape : optional array spec appended to the temporaries
#:def random_complex(type, name, shape='')
#:set real_type = type.replace('complex','real')
block
    ${real_type}$ :: re${shape}$
    ${real_type}$ :: im${shape}$
    call random_number(im)
    call random_number(re)
    ! cmplx without kind= yields default complex; request the kind of the
    ! temporaries so that double precision values are not truncated.
    ${name}$ = cmplx(re,im,kind=kind(re))
end block
#:enddef
|
||
#! Handles default values of the optional | ||
#:def defaults(**kwargs) | ||
#:for variable, default in kwargs.items() | ||
if (present(${variable}$)) then | ||
local_${variable}$ = ${variable}$ | ||
else | ||
local_${variable}$ = ${default}$ | ||
end if | ||
#! Declares one named constant of type `dtype` per keyword argument
#! (used for the working precision).
#:def parameter(dtype, **kwargs)
#:for constant in kwargs
${dtype}$, parameter :: ${constant}$ = ${kwargs[constant]}$
#:endfor
#:enddef
|
||
#! Emits the `import` statement for the precision constants needed inside an
#! interface body.
#!   pfxs : prefixes whose kind constants must be imported
#! Duplicates are removed and the kinds are sorted so that the generated
#! source is identical from one preprocessor run to the next.
#:def imports(pfxs)
#:set wps = sorted(set(map(kind, pfxs)))
import :: ${', '.join(wps)}$
#:enddef
|
||
#! Declares every name in `args` as a dummy argument of type `dtype`
#! with the given `intent`, one declaration per line.
#:def args(dtype, intent, *args)
#:for dummy in args
${dtype}$, intent(${intent}$) :: ${dummy}$
#:endfor
#:enddef
|
||
#! Handles parameters (usage: working precision) | ||
#:def parameter(dtype, **kwargs) | ||
#:for variable, value in kwargs.items() | ||
${dtype}$, parameter :: ${variable}$ = ${value}$ | ||
#! Declares each name in `args` as an optional dummy argument and, next to it,
#! a local variable `local_<name>` of the same type.
#:def optional(dtype, intent, *args)
#:for dummy in args
${dtype}$, intent(${intent}$), optional :: ${dummy}$
${dtype}$ :: local_${dummy}$
#:endfor
#:enddef
|
||
#! Handles the implementation of the modern interface to each supported type and kind | ||
#:def mfi_implement(name, supports, code, f=lambda x: x) | ||
#:for PREFIX in supports | ||
#:set MFI_NAME = "mfi_" + name.replace('?',f(PREFIX)) | ||
#:set F77_NAME = name.replace('?',f(PREFIX)) | ||
#:set TYPE = PREFIX_TO_TYPE.get(PREFIX,None) | ||
#:set KIND = PREFIX_TO_KIND.get(PREFIX,None) | ||
$:code(MFI_NAME,F77_NAME,TYPE,KIND,PREFIX) | ||
#! Assigns `local_<name>` for every keyword argument: the caller's value when
#! the optional argument is present, the given default otherwise.
#:def defaults(**kwargs)
#:for dummy in kwargs
if (present(${dummy}$)) then
    local_${dummy}$ = ${dummy}$
else
    local_${dummy}$ = ${kwargs[dummy]}$
end if
#:endfor
#:enddef
|
||
#! Define mfi interfaces to implemented routines | ||
#:def mfi_interface(name, types, f=lambda x: x) | ||
interface mfi_${name.replace('?','')}$ | ||
#:for T in types | ||
module procedure mfi_${name.replace('?',f(T))}$ | ||
#:endfor | ||
end interface | ||
#! Emits an if/else that assigns `true_value` to `variable` when `condition`
#! holds and `false_value` otherwise.
#:def optval(condition, variable, true_value, false_value)
if (${condition}$) then
    ${variable}$ = ${true_value}$
else
    ${variable}$ = ${false_value}$
end if
#:enddef
|
||
#! Define f77 interfaces to implemented routines | ||
#:def f77_interface_improved(name, types, f=lambda x: x) | ||
interface f77_${name.replace('?','')}$ | ||
#:for T in types | ||
procedure :: ${name.replace('?',f(T))}$ | ||
#! Emits an interface block named `name` with one `<procedure> :: <member>`
#! line for each entry of `functions`.
#:def interface(functions, procedure='procedure', name='')
interface ${name}$
#:for member in functions
    ${procedure}$ :: ${member}$
#:endfor
end interface
#:enddef
|
||
#! Define a f77 interfaces to the external blas/lapack library | ||
#:def f77_interface(name, supports, code, f=lambda x: x, improved_f77=True) | ||
|
||
#! Interfaces for the original f77 routines.
#!   generic_name : routine name with '?' standing for the prefix
#!   prefixes     : prefixes (or prefix combinations) to generate
#!   code         : macro emitting one routine interface, called as
#!                  code(name, pfxs) for every entry of `prefixes`
#:def f77_original(generic_name, prefixes, code)
#:set mfi = 'mfi_' + prefix('',generic_name)
#:set f77 = 'f77_' + prefix('',generic_name)
!> ${generic_name}$ supports ${', '.join(prefixes)}$.
!> See also: [[${mfi}$]], [[${f77}$]].
interface
#:for pfx in prefixes
#:set name = prefix(pfx,generic_name)
#:set pfxs = list(map(split,pfx))
$:code(name,pfxs)
#:endfor
end interface
#:enddef
|
||
#:if improved_f77 | ||
$:f77_interface_improved(name, supports, f=f) | ||
#:endif | ||
#! Groups the original f77 routines under one generic interface, so the
#! routine can be called without its prefix.
#:def f77_improved(generic_name, prefixes)
#:set generic = 'f77_' + prefix('', generic_name)
#:set members = [prefix(pfx, generic_name) for pfx in prefixes]
$:interface(members, name=generic)
#:enddef
|
||
#! For missing functions / extensions: `code` must provide the routine
#! implementation and is called as code(name, pfxs) for each prefix.
#! Must be called inside a contains block.
#:def f77_implement(generic_name, prefixes, code)
#:for pfx in prefixes
$:code(prefix(pfx, generic_name), [split(p) for p in pfx])
#:endfor
#:enddef
|
||
#! Emits the generic `mfi_<name>` interface listing one
#! `module procedure mfi_<prefixed name>` per entry of `prefixes`.
#:def mfi_interface(generic_name, prefixes)
#:set generic = 'mfi_' + prefix('', generic_name)
#:set members = ['mfi_' + prefix(pfx, generic_name) for pfx in prefixes]
$:interface(members, procedure='module procedure', name=generic)
#:enddef
|
||
#! Implements a f77 function / extension | ||
#:def f77_implement(name, supports, code) | ||
#:for PREFIX in supports | ||
#:set NAME = name.replace('?',PREFIX) | ||
#:set TYPE = PREFIX_TO_TYPE.get(PREFIX,None) | ||
#:set KIND = PREFIX_TO_KIND.get(PREFIX,None) | ||
$:code(NAME,TYPE,KIND,PREFIX) | ||
#! Implements the modern interface for each supported prefix combination by
#! calling code(mfi_name, f77_name, pfxs).
#! Must be called inside a contains block.
#:def mfi_implement(generic_name, prefixes, code)
#:for pfx in prefixes
#:set f77_name = prefix(pfx, generic_name)
$:code('mfi_' + f77_name, f77_name, [split(p) for p in pfx])
#:endfor
#:enddef
|
||
#! Implements a test | ||
#:def test_implement(name, supports, code, f=lambda x: x) | ||
#:for PREFIX in supports | ||
#:set ORIGINAL = name.replace('?',f(PREFIX)) | ||
#:set IMPROVED = "f77_" + name.replace('?','') | ||
#:set MODERN = "mfi_" + name.replace('?','') | ||
#:set TYPE = PREFIX_TO_TYPE.get(PREFIX,None) | ||
#:set KIND = PREFIX_TO_KIND.get(PREFIX,None) | ||
$:code(ORIGINAL,IMPROVED,MODERN,TYPE,KIND,PREFIX) | ||
|
||
#! Implements the test for all interfaces and each supported prefix
#! combination by calling code(f77, f90, mfi, pfxs), where f77 is the prefixed
#! routine and f90 / mfi are the prefix-free 'f77_' / 'mfi_' generic names.
#! Must be called inside a contains block.
#:def test_implement(generic_name, prefixes, code)
#:set stem = prefix('', generic_name)
#:for pfx in prefixes
$:code(prefix(pfx, generic_name), 'f77_' + stem, 'mfi_' + stem, [split(p) for p in pfx])
#:endfor
#:enddef
|
||
#! Call the subroutine test | ||
#:def test_run(name, supports, f=lambda x: x) | ||
#:for PREFIX in supports | ||
#:set ORIGINAL = name.replace('?',f(PREFIX)) | ||
#:set MODERN = "mfi_" + name.replace('?','') | ||
@:timeit("testing ${MODERN}$ against ${ORIGINAL}$", { call test_${ORIGINAL}$ }) | ||
#! Emits, for every prefix, a timed call to the subroutine test_<routine>.
#:def test_run(generic_name, prefixes)
#:set modern = 'mfi_' + prefix('', generic_name)
#:for pfx in prefixes
#:set routine = prefix(pfx, generic_name)
@:timeit("testing ${modern}$ against ${routine}$", { call test_${routine}$ })
#:endfor
#:enddef
|
||
#:def timeit(message, code) | ||
block | ||
real :: t1, t2 | ||
call cpu_time(t1) | ||
$:code | ||
call cpu_time(t2) | ||
print '(A," (",G0,"s)")', ${message}$, t2-t1 | ||
end block | ||
#! Interface of the Fortran 77 ?rot routines.
#!   name : concrete routine name
#!   pfxs : one or two prefixes; the first selects the type of x and y (and
#!          the real kind of c), the second - when present - the type of s
#:def rot_f77(name,pfxs)
#:set A = pfxs[0]
#:set B = A if len(pfxs) == 1 else pfxs[1]
!> ${name.upper()}$ applies a plane rotation.
pure subroutine ${name}$(n, x, incx, y, incy, c, s)
$:imports(pfxs)
#! The rotation overwrites both vectors, so x and y must be intent(inout).
@:args(${type(A)}$, inout, x(*), y(*))
@:args(${real(A)}$, in, c)
@:args(${type(B)}$, in, s)
    integer, intent(in) :: n, incx, incy
end subroutine
#:enddef
|
||
#:def random_complex(name, shape='') | ||
block | ||
${REAL_TYPE}$ :: re${shape}$ | ||
${REAL_TYPE}$ :: im${shape}$ | ||
call random_number(im) | ||
call random_number(re) | ||
${name}$ = cmplx(re,im) | ||
end block | ||
#! Modern wrapper around the ?rot routines.
#!   mfi_name : name of the generated wrapper
#!   f77_name : name of the routine being wrapped
#!   pfxs     : one or two prefixes; the first selects the type of x and y
#!              (and the real kind of c), the second the type of s
#! incx / incy are optional and default to 1.
#! NOTE(review): n is always size(x), whatever incx/incy are - confirm this
#! is intended for non-unit increments.
#:def rot_mfi(mfi_name,f77_name,pfxs)
#:set A = pfxs[0]
#:set B = pfxs[1] if len(pfxs) > 1 else A
!> Given two vectors x and y,
!> each vector element of these vectors is replaced as follows:
!>```fortran
#:if type(A) == complex(A)
!> xi = c*xi + s*yi
!> yi = c*yi - conj(s)*xi
#:elif type(A) == real(A)
!> xi = c*xi + s*yi
!> yi = c*yi - s*xi
#:endif
!>```
pure subroutine ${mfi_name}$(x, y, c, s, incx, incy)
@:args(${type(A)}$, inout, x(:), y(:))
@:args(${real(A)}$, in, c)
@:args(${type(B)}$, in, s)
@:optional(integer, in, incx, incy)
    integer :: n
    n = size(x)
@:defaults(incx=1, incy=1)
    call ${f77_name}$(n,x,local_incx,y,local_incy,c,s)
end subroutine
#:enddef
|
||
#! Prefixes generated for ?rot: the default types plus the mixed
#! complex/real combinations.
#:set ROT_PREFIXES = DEFAULT_TYPES + mix(COMPLEX_TYPES,REAL_TYPES)
$:f77_original('?rot', ROT_PREFIXES, rot_f77)
$:mfi_interface('?rot', ROT_PREFIXES)
$:mfi_implement('?rot', ROT_PREFIXES, rot_mfi)
|
||
|
||
#:endmute |
Oops, something went wrong.