From 94543c38feff8877fbbfecba8f533ad39bfffab0 Mon Sep 17 00:00:00 2001 From: Alex Nitz Date: Thu, 14 Dec 2023 14:18:21 -0500 Subject: [PATCH] fixes to mkl for large sizes (#4583) * prototype fixes * Update mkl.py * ws * also update class interface * typo --- pycbc/fft/mkl.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/pycbc/fft/mkl.py b/pycbc/fft/mkl.py index 6bd75ee0c9c..08bcf72699b 100644 --- a/pycbc/fft/mkl.py +++ b/pycbc/fft/mkl.py @@ -54,16 +54,16 @@ DFTI_PERM_FORMAT = 56 DFTI_CCE_FORMAT = 57 -mkl_prec = {'single': DFTI_SINGLE, - 'double': DFTI_DOUBLE, - } - mkl_domain = {'real': {'complex': DFTI_REAL}, 'complex': {'real': DFTI_REAL, 'complex':DFTI_COMPLEX, } } +mkl_descriptor = {'single': lib.DftiCreateDescriptor_s_1d, + 'double': lib.DftiCreateDescriptor_d_1d, + } + def check_status(status): """ Check the status of a mkl functions and raise a python exeption if there is an error. @@ -76,15 +76,14 @@ def check_status(status): def create_descriptor(size, idtype, odtype, inplace): invec = zeros(1, dtype=idtype) outvec = zeros(1, dtype=odtype) - desc = ctypes.c_void_p(1) - f = lib.DftiCreateDescriptor - f.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int] - prec = mkl_prec[invec.precision] domain = mkl_domain[str(invec.kind)][str(outvec.kind)] + f = mkl_descriptor[invec.precision] + f.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_long] + + status = f(ctypes.byref(desc), domain, size) - status = f(ctypes.byref(desc), prec, domain, 1, size) if inplace: lib.DftiSetValue(desc, DFTI_PLACEMENT, DFTI_INPLACE) else: @@ -120,15 +119,14 @@ def ifft(invec, outvec, prec, itype, otype): # Class based API -_create_descr = lib.DftiCreateDescriptor -_create_descr.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int] - def _get_desc(fftobj): desc = ctypes.c_void_p(1) - prec = mkl_prec[fftobj.invec.precision] domain = mkl_domain[str(fftobj.invec.kind)][str(fftobj.outvec.kind)] - status = _create_descr(ctypes.byref(desc), prec, domain, - 1, int(fftobj.size)) + + f = mkl_descriptor[fftobj.invec.precision] + f.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_long] + status = f(ctypes.byref(desc), domain, int(fftobj.size)) + check_status(status) # Now we set various things depending on exactly what kind of transform we're # performing.