Introduce Rust extension #189

Merged
merged 51 commits into from
Jan 12, 2024

Conversation

alecandido
Member

@alecandido alecandido commented Jan 9, 2023

Closes #185

The working layout reference is https://github.com/AleCandido/atuin/tree/main/full

@alecandido alecandido changed the base branch from master to develop January 9, 2023 12:39
@alecandido alecandido marked this pull request as draft January 9, 2023 12:46
@felixhekhorn felixhekhorn added refactor Refactor code rust Rust extension related labels Jan 9, 2023
@felixhekhorn
Contributor

To have a LO NS proof-of-concept implementation we need

The complexity increases by considering the singlet case XOR higher orders and even more by considering them simultaneously (different solutions). Can we save

def check_gamma_1_pegasus(N, NF):
which is a very strong check?

@alecandido
Member Author

Wait a second: this is just a proof of concept, for the time being limited to stabilizing the tooling and the interface.

The minimal step for a production product would be to migrate (part of) the splitting functions expressions, but neither the integration nor the interpolation (that is the more ambitious plan, but learning from #172 we should limit our own ambition).

Unfortunately, I'd like to do it incrementally, but I'm not sure I can: is there a way to use a compiled extension callable inside a Numba decorated function?

@felixhekhorn
Contributor

> Unfortunately, I'd like to do it incrementally, but I'm not sure I can: is there a way to use a compiled extension callable inside a Numba decorated function?

Mmm ... reading here I didn't understand whether numba can read CFFI calls, because here I think they do the other way round: calling numba from C ... this section on numpy might also be interesting ...

@alecandido
Member Author

> Mmm ... reading here I didn't understand whether numba can read CFFI calls, because here I think they do the other way round: calling numba from C ... this section on numpy might also be interesting ...

All the resources are relevant, especially second and third.
We have the further problem that we are not directly binding to C ourselves. But I would expect it to be complicated anyhow (though worth having a look, not soon).
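
As a rough illustration of the question above, here is a minimal sketch of calling a compiled C symbol from inside a Numba nopython function through ctypes. It assumes a Unix-like system where the C math library can be located, and it uses `cos` from libm purely as a stand-in for "some compiled extension"; the function name `sum_cos` is hypothetical. It does not show that a Rust/maturin module can be called this way: that would presumably require exposing plain C ABI symbols rather than Python callables.

```python
import ctypes
import ctypes.util
import math

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without Numba (then nothing is compiled).
    def njit(func):
        return func

# Assumption: a Unix-like system; libm's `cos` stands in for a compiled extension.
libm = ctypes.CDLL(ctypes.util.find_library("m"))
c_cos = libm.cos
# Numba needs the signature declared on the ctypes function pointer.
c_cos.argtypes = [ctypes.c_double]
c_cos.restype = ctypes.c_double


@njit
def sum_cos(n):
    """Sum cos(0.1 * i) for i < n, calling the C function inside the loop."""
    total = 0.0
    for i in range(n):
        total += c_cos(0.1 * i)
    return total


reference = sum(math.cos(0.1 * i) for i in range(5))
print(sum_cos(5), reference)
```

The point of the sketch is only that nopython mode accepts ctypes function pointers with declared `argtypes` and `restype`; whether the same route is practical for the extension discussed here is an open question in this thread.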

@alecandido alecandido changed the base branch from develop to master January 27, 2023 14:09
@felixhekhorn
Contributor

As a side remark: numba is not the fastest - something we suspected already ...

t(f2py)/t(numba) ≈ 0.28 for 1e4 repetitions of summing 100 different values of $\psi_0$

check script
import timeit
import sys

import numpy as np
import numba as nb

import cernlib

@nb.njit(cache=True)
def cern_polygamma(Z, K):
    # fmt: off
    DELTA = 5e-13
    R1 = 1
    HF = R1/2
    C1 = np.pi**2
    C2 = 2*np.pi**3
    C3 = 2*np.pi**4
    C4 = 8*np.pi**5

    # SGN is originally indexed 0:4 -> no shift
    SGN = [-1,+1,-1,+1,-1]
    # FCT is originally indexed -1:4 -> shift +1
    FCT = [0,1,1,2,6,24]

    # C is originally indexed 1:6 x 0:4 -> swap indices and shift new last -1
    C = nb.typed.List()
    C.append([
            8.33333333333333333e-2,
           -8.33333333333333333e-3,
            3.96825396825396825e-3,
           -4.16666666666666667e-3,
            7.57575757575757576e-3,
           -2.10927960927960928e-2])
    C.append([
            1.66666666666666667e-1,
           -3.33333333333333333e-2,
            2.38095238095238095e-2,
           -3.33333333333333333e-2,
            7.57575757575757576e-2,
           -2.53113553113553114e-1])
    C.append([
            5.00000000000000000e-1,
           -1.66666666666666667e-1,
            1.66666666666666667e-1,
           -3.00000000000000000e-1,
            8.33333333333333333e-1,
           -3.29047619047619048e+0])
    C.append([
            2.00000000000000000e+0,
           -1.00000000000000000e+0,
            1.33333333333333333e+0,
           -3.00000000000000000e+0,
            1.00000000000000000e+1,
           -4.60666666666666667e+1])
    C.append([10., -7., 12., -33., 130., -691.])
    U=Z
    X=np.real(U)
    A=np.abs(X)
    if K < 0 or K > 4:
        raise NotImplementedError("Order K has to be in [0:4]")
    if np.abs(np.imag(U)) < DELTA and np.abs(X+int(A)) < DELTA:
        raise ValueError("Argument Z equals non-positive integer")
    K1=K+1
    if X < 0:
        U=-U
    V=U
    H=0
    if A < 15:
        H=1/V**K1
        for I in range(1,14-int(A)+1):
            V=V+1
            H=H+1/V**K1
        V=V+1
    R=1/V**2
    P=R*C[K][6-1]
    for I in range(5,1-1,-1):
        P=R*(C[K][I-1]+P)
    H=SGN[K]*(FCT[K+1]*H+(V*(FCT[K-1+1]+P)+HF*FCT[K+1])/V**K1)
    if K == 0:
        H=H+np.log(V)
    if X < 0:
        V=np.pi*U
        X=np.real(V)
        Y=np.imag(V)
        A=np.sin(X)
        B=np.cos(X)
        T=np.tanh(Y)
        P=complex(B,-A*T)/complex(A,B*T)
        if K == 0:
            H=H+1/U+np.pi*P
        elif K == 1:
            H=-H+1/U**2+C1*(P**2+1)
        elif K == 2:
            H=H+2/U**3+C2*P*(P**2+1)
        elif K == 3:
            R=P**2
            H=-H+6/U**4+C3*((3*R+4)*R+1)
        elif K == 4:
            R=P**2
            H=H+24/U**5+C4*P*((3*R+5)*R+2)
    return H
    # fmt: on
    

print("Compile numba")
cern_polygamma(1.,1)

ns = 1. + np.linspace(0,10.,100) * 1j

def a(): return np.sum([cernlib.wpsipg(n,1) for n in ns])
def b(): return np.sum([cern_polygamma(n,1) for n in ns])

num = int(float(sys.argv[1]))

aa = timeit.timeit(a,number=num)
print("f2py",aa)
bb = timeit.timeit(b,number=num)
print("numba",bb)
print("f2py/numba",aa/bb)
f2py3.10 -m cernlib wpsipg.F -c
      FUNCTION WPSIPG(Z,K)             
cf2py intent(in) Z
cf2py intent(in) K                                         
                                                                                
      IMPLICIT DOUBLE PRECISION (A-H,O-Z)                                       
      COMPLEX*16 WPSIPG,Z,U,V,H,R,P,GCMPLX                                      
      CHARACTER NAME*(*)                                                        
      CHARACTER*80 ERRTXT                                                       
      PARAMETER (NAME = 'CPSIPG/WPSIPG')                                        
      DIMENSION C(6,0:4),FCT(-1:4),SGN(0:4)                                     
                                                                                
      PARAMETER (DELTA = 5D-13)                                                 
      PARAMETER (R1 = 1, HF = R1/2)                                             
      PARAMETER (PI = 3.14159 26535 89793 24D0)                                 
      PARAMETER (C1 = PI**2, C2 = 2*PI**3, C3 = 2*PI**4, C4 = 8*PI**5)          
                                                                                
      DATA SGN /-1,+1,-1,+1,-1/, FCT /0,1,1,2,6,24/                             
                                                                                
      DATA C(1,0) / 8.33333 33333 33333 33D-2/                                  
      DATA C(2,0) /-8.33333 33333 33333 33D-3/                                  
      DATA C(3,0) / 3.96825 39682 53968 25D-3/                                  
      DATA C(4,0) /-4.16666 66666 66666 67D-3/                                  
      DATA C(5,0) / 7.57575 75757 57575 76D-3/                                  
      DATA C(6,0) /-2.10927 96092 79609 28D-2/                                  
                                                                                
      DATA C(1,1) / 1.66666 66666 66666 67D-1/                                  
      DATA C(2,1) /-3.33333 33333 33333 33D-2/                                  
      DATA C(3,1) / 2.38095 23809 52380 95D-2/                                  
      DATA C(4,1) /-3.33333 33333 33333 33D-2/                                  
      DATA C(5,1) / 7.57575 75757 57575 76D-2/                                  
      DATA C(6,1) /-2.53113 55311 35531 14D-1/                                  
                                                                                
      DATA C(1,2) / 5.00000 00000 00000 00D-1/                                  
      DATA C(2,2) /-1.66666 66666 66666 67D-1/                                  
      DATA C(3,2) / 1.66666 66666 66666 67D-1/                                  
      DATA C(4,2) /-3.00000 00000 00000 00D-1/                                  
      DATA C(5,2) / 8.33333 33333 33333 33D-1/                                  
      DATA C(6,2) /-3.29047 61904 76190 48D+0/                                  
                                                                                
      DATA C(1,3) / 2.00000 00000 00000 00D+0/                                  
      DATA C(2,3) /-1.00000 00000 00000 00D+0/                                  
      DATA C(3,3) / 1.33333 33333 33333 33D+0/                                  
      DATA C(4,3) /-3.00000 00000 00000 00D+0/                                  
      DATA C(5,3) / 1.00000 00000 00000 00D+1/                                  
      DATA C(6,3) /-4.60666 66666 66666 67D+1/                                  
                                                                                
      DATA (C(I,4),I=1,6) / 10, -7, 12, -33, 130, -691/                         
                                                                                
      GCMPLX(X,Y)=DCMPLX(X,Y)                                                   
                                                                                
      U=Z                                                                       
      X=U                                                                       
      A=ABS(X)                                                                  
      IF(K .LT. 0 .OR. K .GT. 4) THEN                                           
       H=0                                                                      
       WRITE(ERRTXT,101) K                                                      
c       CALL MTLPRT(NAME,'C317.1',ERRTXT)                                        
#if !defined(CERNLIB_GFORTRAN)
      ELSEIF(ABS(IMAG(U)) .LT. DELTA .AND. ABS(X+NINT(A)) .LT. DELTA)           
#else
      ELSEIF(ABS(AIMAG(U)) .LT. DELTA .AND. ABS(X+NINT(A)) .LT. DELTA)           
#endif
     1                                                        THEN              
       H=0                                                                      
       WRITE(ERRTXT,102) X                                                      
c       CALL MTLPRT(NAME,'C317.2',ERRTXT)                                        
      ELSE                                                                      
       K1=K+1                                                                   
       IF(X .LT. 0) U=-U                                                        
       V=U                                                                      
       H=0                                                                      
       IF(A .LT. 15) THEN                                                       
        H=1/V**K1                                                               
        DO 1 I = 1,14-INT(A)                                                    
        V=V+1                                                                   
    1   H=H+1/V**K1                                                             
        V=V+1                                                                   
       END IF                                                                   
       R=1/V**2                                                                 
       P=R*C(6,K)                                                               
       DO 2 I = 5,1,-1                                                          
    2  P=R*(C(I,K)+P)                                                           
       H=SGN(K)*(FCT(K)*H+(V*(FCT(K-1)+P)+HF*FCT(K))/V**K1)                     
       IF(K .EQ. 0) H=H+LOG(V)                                                  
       IF(X .LT. 0) THEN                                                        
        V=PI*U                                                                  
        X=V                                                                     
#if !defined(CERNLIB_GFORTRAN)
        Y=IMAG(V)                                                               
#else
        Y=AIMAG(V)                                                               
#endif
        A=SIN(X)                                                                
        B=COS(X)                                                                
        T=TANH(Y)                                                               
        P=GCMPLX(B,-A*T)/GCMPLX(A,B*T)                                          
        IF(K .EQ. 0) THEN                                                       
         H=H+1/U+PI*P                                                           
        ELSEIF(K .EQ. 1) THEN                                                   
         H=-H+1/U**2+C1*(P**2+1)                                                
        ELSEIF(K .EQ. 2) THEN                                                   
         H=H+2/U**3+C2*P*(P**2+1)                                               
        ELSEIF(K .EQ. 3) THEN                                                   
         R=P**2                                                                 
         H=-H+6/U**4+C3*((3*R+4)*R+1)                                           
        ELSEIF(K .EQ. 4) THEN                                                   
         R=P**2                                                                 
         H=H+24/U**5+C4*P*((3*R+5)*R+2)                                         
        ENDIF                                                                   
       ENDIF                                                                    
      ENDIF                                                                     
      WPSIPG=H                                                                  
      RETURN                                                                    
  101 FORMAT('K = ',I5,'  (< 0  OR  > 4)')                                      
  102 FORMAT('ARGUMENT EQUALS NON-POSITIVE INTEGER = ',F8.1)                    
      END                                                                       
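
For readers wondering where the table `C` in both versions comes from: its entries appear to match $B_{2k}\,(2k+K-1)!/(2k)!$ for $k=1,\dots,6$, with $B_{2k}$ the Bernoulli numbers, i.e. the coefficients of the asymptotic expansion of the polygamma function of order $K$. The sketch below checks this against two rows copied from the script; the helper names `bernoulli_numbers` and `expansion_coefficients` are hypothetical and not part of eko or cernlib.

```python
from fractions import Fraction
from math import comb, factorial


def bernoulli_numbers(n_max):
    """Return [B_0, ..., B_n_max] as exact fractions (convention B_1 = -1/2)."""
    B = [Fraction(1)]
    for m in range(1, n_max + 1):
        acc = sum(comb(m + 1, j) * B[j] for j in range(m))
        B.append(-acc / (m + 1))
    return B


def expansion_coefficients(K, n_terms=6):
    """Coefficients B_{2k} * (2k + K - 1)! / (2k)! for k = 1..n_terms."""
    B = bernoulli_numbers(2 * n_terms)
    return [
        B[2 * k] * Fraction(factorial(2 * k + K - 1), factorial(2 * k))
        for k in range(1, n_terms + 1)
    ]


# Rows K=0 and K=4 copied from the table C in the script above.
row0 = [8.33333333333333333e-2, -8.33333333333333333e-3, 3.96825396825396825e-3,
        -4.16666666666666667e-3, 7.57575757575757576e-3, -2.10927960927960928e-2]
row4 = [10.0, -7.0, 12.0, -33.0, 130.0, -691.0]

for K, row in ((0, row0), (4, row4)):
    exact = expansion_coefficients(K)
    worst = max(abs(float(e) - r) for e, r in zip(exact, row))
    print(f"K={K}: largest deviation from the table = {worst:.1e}")
```

Exact fractions are used so that the comparison is limited only by the precision of the tabulated values.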

@felixhekhorn
Contributor

Not using numba on quad_ker would cost us big time

t(numba)/t(raw) ~ 0.005 for 100 times calling quad_ker for 100 points

check script
import timeit
import sys

import numba as nb
import numpy as np

from eko.evolution_operator import quad_ker
from eko import basis_rotation as br

quad_ker_nb = nb.njit(cache=True)(quad_ker)

args = dict(
    order=(1, 0),
    mode0=br.non_singlet_pids_map["ns+"],
    mode1=0,
    method="",
    is_log=True,
    logx=0.1,
    areas=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    as_list=np.array([2.0, 1.0]),
    mu2_from=1.0,
    mu2_to=2.0,
    a_half=np.array([[1.5, 0.01]]),
    alphaem_running=False,
    nf=3,
    L=0,
    ev_op_iterations=1,
    ev_op_max_order=(1, 0),
    sv_mode=1,
    is_threshold=False,
    is_polarized=False,
    is_time_like=False,
)

print("Compile numba")
quad_ker_nb(.6,**args)

us = np.linspace(.5, 1., 100)

def a(): return np.sum([quad_ker(u,**args) for u in us])
def b(): return np.sum([quad_ker_nb(u,**args) for u in us])

num = int(float(sys.argv[1]))

bb = timeit.timeit(b,number=num)
print("numba",bb)
aa = timeit.timeit(a,number=num)
print("raw",aa)
print("numba/raw",bb/aa)

together with this eko hack:

diff --git a/src/eko/evolution_operator/__init__.py b/src/eko/evolution_operator/__init__.py
index 335dfa4f..4e99f225 100644
--- a/src/eko/evolution_operator/__init__.py
+++ b/src/eko/evolution_operator/__init__.py
@@ -184,7 +184,7 @@ class QuadKerBase:
         return self.path.prefactor * pj * self.path.jac


-@nb.njit(cache=True)
+#@nb.njit(cache=True)
 def quad_ker(
     u,
     order,

@alecandido
Member Author

Is it possible that the class is making everything worse? Or is it just a Python call overhead?

To get more info, we could write a function with features similar to quad_ker, and see how much Numba gains.
But it is clear that just commenting out the decorator is not an option, even though it would of course have been the easiest one (had it worked properly).
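
One way to set up such a comparison is sketched below. The function `toy_kernel` is a hypothetical stand-in (complex arithmetic in a tight loop), not a reduction of the real quad_ker, so the ratio it prints says nothing definite about eko itself. If Numba is not installed, the fallback leaves the function uncompiled and the ratio is close to 1.

```python
import timeit

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Fallback so the sketch still runs without Numba (then nothing is compiled).
    def njit(func):
        return func


def toy_kernel(u, n_terms):
    """Hypothetical stand-in for quad_ker: complex arithmetic in a tight loop."""
    n = complex(0.5, u)
    acc = 0.0 + 0.0j
    for k in range(1, n_terms + 1):
        acc += 1.0 / (n + k) ** 2
    return acc.real


toy_kernel_nb = njit(toy_kernel)

# Trigger compilation before timing, as in the check scripts above.
toy_kernel_nb(0.6, 10)

us = [0.5 + 0.005 * i for i in range(100)]


def run_raw():
    return sum(toy_kernel(u, 200) for u in us)


def run_nb():
    return sum(toy_kernel_nb(u, 200) for u in us)


t_raw = timeit.timeit(run_raw, number=20)
t_nb = timeit.timeit(run_nb, number=20)
print("numba/raw", t_nb / t_raw)
```

Varying `n_terms` shifts the balance between the per-call overhead and the work done inside the compiled function, which is the distinction raised in the question above.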

@alecandido
Member Author

Actually @felixhekhorn, we should run the test at NNLO: if we lose a lot of time at LO we don't care so much, since it is already fast enough. What matters is how much we lose on an already expensive computation: if the overhead is negligible, because the rest is much slower, for me it would be fine to spoil LO for a while.

@alecandido
Member Author

@felixhekhorn
Contributor

felixhekhorn commented Jul 31, 2023

@felixhekhorn
Contributor

63eb191 adds a quick and dirty bibtex -> fake-crate parser, which does its job for the moment. We can and will update it the moment we need it.

If you agree @alecandido @giacomomagni this can be merged. (just as a reminder: this does not touch the Python program in any way and if you want to use this new feature you have to apply the provided patch)

@alecandido
Member Author

Btw, I won't be able to approve, since I opened the PR...

@giacomomagni
Collaborator

@felixhekhorn can we merge here, so we can go on with it?

@felixhekhorn
Contributor

> @felixhekhorn can we merge here, so we can go on with it?

the packaging issue with maturin needs to be resolved first ... I'm still tempted by postponing the problem by creating another .patch file as I'm not as familiar with Python packaging as @alecandido 😇 ... as far as I understood we would also need to touch the workflows which would ripple through the whole NNPDF packages ...

@alecandido
Member Author

> the packaging issue with maturin needs to be resolved first

Correct

> as far as I understood we would also need to touch the workflows which would ripple through the whole NNPDF packages ...

Almost correct: the current situation is not good, because the dependencies are specified under the Poetry-specific sections, but the build backend comes from Maturin, with the effect of forgetting the EKO dependencies.

However, the part I was worried about might be a non-issue, since poetry build is only the frontend, and if maturin is specified as the backend, that one will be used (not poetry-core, which indeed has to be specified separately, even when using the poetry frontend to build).

So, you might not need any change to the workflow at all for deploying, but I believe you'd need to duplicate the dependencies if you want to use the testing workflow in EKO (because it will install the environment with Poetry, which is not aware of the deps if they are not specified in its own sections, and those sections are of course not read by Maturin, which reads the standard ones...).
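
To make the duplication concrete, here is a small sketch that compares the standard `[project]` dependencies (the ones read by Maturin) with `[tool.poetry.dependencies]` (the ones Poetry installs). The embedded pyproject.toml is hypothetical and is not the actual EKO file; the package names, versions and the helper `dependency_names` are made up for illustration.

```python
import re

try:
    import tomllib  # Python >= 3.11
except ModuleNotFoundError:
    import tomli as tomllib  # backport with the same API

# Hypothetical pyproject.toml, NOT the actual EKO configuration.
PYPROJECT = """
[build-system]
requires = ["maturin>=1.1,<2.0"]
build-backend = "maturin"

[project]
name = "example"
version = "0.0.0"
dependencies = ["numpy>=1.24", "scipy>=1.10"]

[tool.poetry.dependencies]
python = "^3.9"
numpy = "^1.24"
scipy = "^1.10"
numba = "^0.59"
"""


def dependency_names(config):
    """Return (standard, poetry) sets of lower-cased dependency names."""
    standard = {
        re.match(r"[A-Za-z0-9._-]+", spec).group(0).lower()
        for spec in config.get("project", {}).get("dependencies", [])
    }
    poetry = {
        name.lower()
        for name in config.get("tool", {}).get("poetry", {}).get("dependencies", {})
        if name.lower() != "python"  # interpreter constraint, not a package
    }
    return standard, poetry


config = tomllib.loads(PYPROJECT)
standard, poetry = dependency_names(config)
print("only in [tool.poetry.dependencies]:", sorted(poetry - standard))
print("only in [project]:", sorted(standard - poetry))
```

A check of this kind could run in CI to catch the two lists drifting apart, as long as both sections have to be maintained side by side.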

> I'm still tempted by postponing the problem by creating another .patch

Of course, that is an option, with the benefit of diverging less from main as time passes. But the content of this PR is mostly orthogonal to everything in main, since it mostly happens in crates/, and what is outside that folder is actually mostly the packaging part (or the part in Python, already contained in a patch).
Thus, the benefit is mostly limited to closing a PR. Which could also be enough, if you wish.

@felixhekhorn
Contributor

felixhekhorn commented Jan 11, 2024

I opted for the lazy option of adding another .patch file and postponing the problem (any help is appreciated 😇). This way we can still close the PR within almost a year 🙃

status:

  • LO unpolarized FFNS QCD LHA is back working
  • we gained some speed
  • opt-in for the new feature by executing rustify.sh (which will patch 3 files)
  • the patches will need to be updated in the future if the interface changes (as is happening in Adding FHMRUVV N3LO splitting functions #335)
  • the new feature is only working locally, i.e. not in workflows or PyPI - this has to be resolved in a future PR

if you agree @alecandido and @giacomomagni let's merge and push for the rest of ekore

@felixhekhorn felixhekhorn merged commit 5e843b2 into master Jan 12, 2024
6 checks passed
@felixhekhorn felixhekhorn deleted the rust branch January 12, 2024 12:19
@felixhekhorn felixhekhorn mentioned this pull request May 29, 2024
13 tasks
@felixhekhorn
Contributor

an attempt using PyO3 is 357cb91

Development

Successfully merging this pull request may close these issues.

Split anomalous dimensions into a Rust crate
3 participants