Changing mesh coordinate after defining problem #43

Open
Simhano opened this issue Nov 18, 2024 · 5 comments

@Simhano

Simhano commented Nov 18, 2024

Hi, JAX-FEM community,

I have a question about changing mesh coordinates to optimize geometry.

I tried to get a derivative with respect to the initial configuration.

Here is my set_params:

    def set_params(self, params):
        self.fes[0].points = params[0]

Then I defined the problem, which is very similar to the inverse demo in JAX-FEM.
Here is how I set up the differentiation:

    problem = HyperElasticity_opt(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info,
                                  location_fns=location_fns, internal_pressure=internal_pressure_2)

    original_cood = mesh.points
    original_cood[non_fixed_nodes] = original_cood[non_fixed_nodes] * 1.1

    params = [original_cood]

    # Implicit differentiation wrapper
    fwd_pred = ad_wrapper(problem)

    sol_list = fwd_pred(params)
    print(sol_list[0])

    def test_fn(sol_list):
        print(sol_list[0])
        return np.sum((sol_list[0] - u_sol_2)**2)

    def composed_fn(params):
        # print()
        return test_fn(fwd_pred(params))

    d_coord = jax.grad(composed_fn)(params)

My derivative (d_coord) was all zeros, and by printing sol_list[0] in the test function I could see that the original coordinates had not changed since the problem was defined. How can I change the original mesh coordinates after defining the problem? If this method does not work, is there another way to do it?

Thank you for reading!

@tianjuxue
Collaborator

Unfortunately, JAX-FEM does not support taking derivatives w.r.t. mesh coordinates. But this type of problem has a very direct workaround. @xwpken Weipeng, could you share the paper that deals with this kind of problem? You need a trick to reformulate your problem (perhaps a smart definition of the deformation gradient F in some alternative reference configuration).
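
A minimal sketch of what such a reformulation could look like. This is an assumption for illustration, not the approach from the paper mentioned above and not a JAX-FEM API. The idea is to keep the mesh fixed and describe the shape change as a field d on that fixed mesh, so the design reference position is X = X_hat + d(X_hat) and the current position is x = X + u. The chain rule then gives F = dx/dX = I + (grad_hat u)(I + grad_hat d)^{-1}, and the volume element scales by det(I + grad_hat d). All names below (`grad_u_hat`, `grad_d_hat`, `energy_per_mesh_volume`) and the material constants are hypothetical.

```python
import jax
import jax.numpy as jnp


def deformation_gradient(grad_u_hat, grad_d_hat):
    """F = dx/dX for X = X_hat + d(X_hat) and x = X + u.

    Both inputs are (3, 3) gradients taken with respect to the fixed
    mesh coordinates X_hat.
    """
    identity = jnp.eye(3)
    shape_map = identity + grad_d_hat  # dX/dX_hat
    # grad_u_hat @ inv(shape_map), computed with a linear solve.
    grad_u = jnp.linalg.solve(shape_map.T, grad_u_hat.T).T
    return identity + grad_u


def volume_scale(grad_d_hat):
    """Ratio dV / dV_hat = det(dX/dX_hat)."""
    return jnp.linalg.det(jnp.eye(3) + grad_d_hat)


def energy_per_mesh_volume(grad_u_hat, grad_d_hat, mu=1.0, kappa=10.0):
    """Compressible neo-Hookean energy density, pulled back to the fixed mesh.

    mu and kappa are illustrative values only.
    """
    F = deformation_gradient(grad_u_hat, grad_d_hat)
    J = jnp.linalg.det(F)
    I1 = jnp.trace(F.T @ F)
    psi = 0.5 * mu * (J ** (-2.0 / 3.0) * I1 - 3.0) + 0.5 * kappa * (J - 1.0) ** 2
    return psi * volume_scale(grad_d_hat)


# Sample gradients at one quadrature point (made-up numbers).
grad_u_hat = jnp.array([[0.10, 0.02, 0.00],
                        [0.00, -0.05, 0.03],
                        [0.01, 0.00, 0.08]])
grad_d_hat = jnp.array([[0.05, 0.00, 0.01],
                        [0.02, 0.03, 0.00],
                        [0.00, 0.01, -0.04]])

# The shape field d is now an ordinary parameter, so its sensitivity is a
# plain jax.grad call and the mesh coordinates themselves never change.
shape_sensitivity = jax.grad(energy_per_mesh_volume, argnums=1)(grad_u_hat, grad_d_hat)
print(shape_sensitivity)
```

With this kind of formulation the design variable enters through the weak form like any other parameter, which is the situation the implicit differentiation wrapper is built for.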

@Simhano
Author

Simhano commented Nov 18, 2024

Thank you so much!
I really appreciate your quick response, and I would sincerely appreciate it if Weipeng could share the paper that deals with this kind of problem!

Is the reason JAX-FEM cannot take derivatives w.r.t. mesh coordinates that the mesh coordinates cannot be changed after the problem is defined?

Thank you so much again!

@BBBBBbruce

Hi, I tried this before. What I found is that the initial positions are loaded with numpy rather than jax.numpy, so if you want autodiff gradients you need to change the loading part to use jax.numpy. I got the gradients and compared them with the finite difference method, and they matched.

However, I can only confirm that this is correct in terms of coding. I am not sure it physically represents what you expect (for example, if there is any discontinuous step in the solving pipeline).
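
For reference, a finite-difference comparison of this kind can be sketched as follows. This is only an illustration under stated assumptions: `objective` is a hypothetical smooth function of the node coordinates standing in for composed_fn, and nothing here calls JAX-FEM. Double precision is switched on so that the central-difference error stays small.

```python
import jax
import jax.numpy as jnp
import numpy as onp

jax.config.update("jax_enable_x64", True)  # double precision for the comparison


def objective(points):
    """Hypothetical stand-in for composed_fn: a scalar function of (N, 3) coordinates."""
    edges = points[1:] - points[:-1]
    lengths = jnp.linalg.norm(edges, axis=1)
    return jnp.sum((lengths - 1.0) ** 2)


def finite_difference_grad(fn, points, eps=1e-6):
    """Central differences, perturbing one coordinate entry at a time."""
    points = onp.asarray(points, dtype=onp.float64)
    grad = onp.zeros_like(points)
    for idx in onp.ndindex(*points.shape):
        step = onp.zeros_like(points)
        step[idx] = eps
        f_plus = float(fn(jnp.asarray(points + step)))
        f_minus = float(fn(jnp.asarray(points - step)))
        grad[idx] = (f_plus - f_minus) / (2.0 * eps)
    return grad


points = jnp.array([[0.0, 0.0, 0.0],
                    [1.2, 0.1, 0.0],
                    [2.0, 0.9, 0.3],
                    [2.5, 1.5, 1.1]])

ad_grad = jax.grad(objective)(points)
fd_grad = finite_difference_grad(objective, points)
max_abs_diff = float(jnp.max(jnp.abs(ad_grad - fd_grad)))
print("max |AD - FD| =", max_abs_diff)
```

Agreement here only shows that the autodiff path is consistent with the code being differentiated. It says nothing about non-smooth steps in the pipeline, which is the caveat raised above.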

@Simhano
Author

Simhano commented Nov 19, 2024

Hi @BBBBBbruce,
Thank you so much for your reply! If it's not too much trouble, could you please share the code showing how you modified JAX-FEM and integrated it into the problem code? I encountered some errors related to JAX NumPy when I tried to make changes, but unfortunately, I wasn't successful. Additionally, I couldn't reproduce the JAX NumPy-related error when I attempted to change the initial coordinates in the problem.

I sincerely appreciate it again!

@BBBBBbruce

Hi, my code is very messy right now, and I don't remember which parts I adjusted for that. You can share the error you encountered and I will have a look.

Also, I think the way you implemented set_params is not right. Updating only the fe coordinates is not sufficient. For example, you need to rerun the code that computes the Dirichlet BC (since your coordinates have been updated, you need to check whether the same sets of faces/nodes are still selected). What I did was call the init function of the problem again, which should update everything you need.
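
A toy illustration of why assigning only the coordinates can leave the gradient at zero. Everything here is hypothetical: `ToyProblem` is not the JAX-FEM problem class, and `edge_lengths` merely plays the role of data derived from the coordinates at initialization. The `left` function follows the style of the location functions in the JAX-FEM demos, which is an assumption about how the boundary nodes are picked.

```python
import jax
import jax.numpy as jnp
import numpy as onp


class ToyProblem:
    """Hypothetical stand-in for a problem class; this is not the JAX-FEM API."""

    def __init__(self, points):
        self.setup(points)

    def setup(self, points):
        self.points = points
        # Derived data that the solve actually uses, computed from the points.
        self.edge_lengths = jnp.linalg.norm(points[1:] - points[:-1], axis=1)

    def set_params_points_only(self, points):
        # Mirrors assigning only the coordinates: edge_lengths stays stale.
        self.points = points

    def set_params_full(self, points):
        # Mirrors re-running the initialization with the new coordinates.
        self.setup(points)

    def solve(self):
        return jnp.sum(self.edge_lengths ** 2)


def left(point):
    """Location function selecting nodes on the plane x = 0."""
    return onp.isclose(point[0], 0.0, atol=1e-5)


def select_nodes(points, location_fn):
    """Boolean mask of the nodes picked by a location function."""
    return onp.array([bool(location_fn(p)) for p in onp.asarray(points)])


initial_points = jnp.array([[0.0, 0.0, 0.0],
                            [1.0, 0.0, 0.0],
                            [2.0, 0.5, 0.0]])
# Scale only the non-fixed nodes (node 0 stays on the boundary).
new_points = initial_points * jnp.array([[1.0], [1.1], [1.1]])


def loss_points_only(points):
    problem = ToyProblem(initial_points)
    problem.set_params_points_only(points)
    return problem.solve()


def loss_full(points):
    problem = ToyProblem(initial_points)
    problem.set_params_full(points)
    return problem.solve()


grad_points_only = jax.grad(loss_points_only)(new_points)  # all zeros
grad_full = jax.grad(loss_full)(new_points)                # depends on the points

# Check that the boundary condition still picks the same nodes after the update.
same_bc_nodes = bool(onp.array_equal(select_nodes(initial_points, left),
                                     select_nodes(new_points, left)))
print(grad_points_only, grad_full, same_bc_nodes)
```

In the first loss the result never depends on the traced input, so JAX returns a zero gradient without raising any error, which matches the symptom reported at the top of this thread.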
