Skip to content

Commit

Permalink
fixes to mkl for large sizes (gwastro#4583)
Browse files Browse the repository at this point in the history
* prototype fixes

* Update mkl.py

* ws

* also update class interface

* typo
  • Loading branch information
ahnitz authored and bhooshan-gadre committed Dec 19, 2023
1 parent 47d5275 commit 94543c3
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions pycbc/fft/mkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@
DFTI_PERM_FORMAT = 56
DFTI_CCE_FORMAT = 57

mkl_prec = {'single': DFTI_SINGLE,
'double': DFTI_DOUBLE,
}

mkl_domain = {'real': {'complex': DFTI_REAL},
'complex': {'real': DFTI_REAL,
'complex':DFTI_COMPLEX,
}
}

mkl_descriptor = {'single': lib.DftiCreateDescriptor_s_1d,
'double': lib.DftiCreateDescriptor_d_1d,
}

def check_status(status):
""" Check the status of a mkl functions and raise a python exeption if
there is an error.
Expand All @@ -76,15 +76,14 @@ def check_status(status):
def create_descriptor(size, idtype, odtype, inplace):
invec = zeros(1, dtype=idtype)
outvec = zeros(1, dtype=odtype)

desc = ctypes.c_void_p(1)
f = lib.DftiCreateDescriptor
f.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]

prec = mkl_prec[invec.precision]
domain = mkl_domain[str(invec.kind)][str(outvec.kind)]
f = mkl_descriptor[invec.precision]
f.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_long]

status = f(ctypes.byref(desc), domain, size)

status = f(ctypes.byref(desc), prec, domain, 1, size)
if inplace:
lib.DftiSetValue(desc, DFTI_PLACEMENT, DFTI_INPLACE)
else:
Expand Down Expand Up @@ -120,15 +119,14 @@ def ifft(invec, outvec, prec, itype, otype):

# Class based API

_create_descr = lib.DftiCreateDescriptor
_create_descr.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]

def _get_desc(fftobj):
desc = ctypes.c_void_p(1)
prec = mkl_prec[fftobj.invec.precision]
domain = mkl_domain[str(fftobj.invec.kind)][str(fftobj.outvec.kind)]
status = _create_descr(ctypes.byref(desc), prec, domain,
1, int(fftobj.size))

f = mkl_descriptor[fftobj.invec.precision]
f.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_long]
status = f(ctypes.byref(desc), domain, int(fftobj.size))

check_status(status)
# Now we set various things depending on exactly what kind of transform we're
# performing.
Expand Down

0 comments on commit 94543c3

Please sign in to comment.