Critic marked as non-trainable #1

Open
aecorredor opened this issue Jul 8, 2023 · 2 comments
@aecorredor

Cross-posting from your awesome Medium article here, just for more awareness.

Maybe I'm mistaken, but in the code the critic is set as non-trainable with critic.trainable = False. I've been trying to port this to TensorFlow.js, and when running training for the first time, I get an error from the critic optimizer:

Error: variableGrads() expects at least one of the input variables to be trainable, 
but none of the 16 variables is trainable.

If I just remove the line that sets it as non-trainable, the model starts training fine. I re-read the tutorial you linked in your article, and the author seems to suggest that the critic there is pre-trained, which is why it's marked as non-trainable. But if we're training from scratch, shouldn't it just be trainable the first time around? After all, we train the critic every epoch; there's even a discriminatorUpdateMultiplier variable that determines how many times to update it.

What am I missing?

@jakespracher
Owner

Interesting, thanks for doing that! Feel free to open a PR if you'd like to contribute this back.

I'll have to look more closely. Looking at the code and its comments:

(this builds gan_model)

# define the combined generator and discriminator model, for updating the
# generator
def define_gan(g_model, d_model, config, learning_rate: float = 0.00005):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    ...

(this runs training and makes use of c_model, g_model, and gan_model)

def _run_epoch(c_model, config, g_model, gan_model, half_batch, labels, vectors):
    c_real_losses, c_fake_losses = [], []
    for _ in range(config.discriminator_update_multiplier):
        real_loss = _update_critic_on_real_samples(
            c_model, config, half_batch, labels, vectors
        )
        c_real_losses.append(real_loss)
        fake_loss = _update_critic_on_fake_samples(c_model, config, g_model, half_batch)
        c_fake_losses.append(fake_loss)

    g_loss = _update_generator(config, gan_model, half_batch)
    return np.mean(c_real_losses), np.mean(c_fake_losses), g_loss

My gut check is that only gan_model has the critic weights fixed, because that model is only used for updating the generator, while c_model is updated independently. But I haven't worked on this project in several years, so I could definitely be mistaken.

That said, if the critic weights were truly fixed, I would expect the model to be totally incapable of learning anything, which doesn't seem to be the case.
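For reference, the usual Keras recipe here is to compile the critic first and only then flip trainable off before compiling the combined model, since Keras reads the trainable flag at compile time (this is called out in the Keras FAQ notes on compile() and trainable). A rough sketch with toy stand-in models, not this repo's actual architectures or Wasserstein loss:

from tensorflow import keras
from tensorflow.keras import layers

# toy stand-ins for the real critic and generator
critic = keras.Sequential(
    [keras.Input(shape=(4,)), layers.Dense(16, activation="relu"), layers.Dense(1)]
)
generator = keras.Sequential(
    [keras.Input(shape=(8,)), layers.Dense(16, activation="relu"), layers.Dense(4)]
)

# 1) compile the critic while it is still trainable, so its own
#    train_on_batch calls keep updating its weights
critic.compile(loss="mse", optimizer=keras.optimizers.RMSprop(0.00005))

# 2) freeze the critic only after it has been compiled and before the combined
#    model is compiled; the freeze is then baked into gan_model alone
critic.trainable = False
gan_model = keras.Sequential([generator, critic])
gan_model.compile(loss="mse", optimizer=keras.optimizers.RMSprop(0.00005))

# expectation: critic.train_on_batch(...) still moves the critic's weights,
# while gan_model.train_on_batch(...) only moves the generator

If the TF.js port applies the freeze to the standalone critic before its optimizer is set up, that would explain the "none of the 16 variables is trainable" error.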

@aecorredor
Author

Woah, really appreciate you answering so quickly. And yeah, what you said makes sense to me. This is pretty much the first ML thing I've tried, so I'm still trying to make sense of the details. Your last statement is what I thought too. I think I just need to figure out how to correctly freeze the critic only for the gan model. With what I'm doing right now, I'm probably freezing both the critic as it's used inside the gan model and the standalone critic model itself, so I suspect some unwanted side effect there. I'll dig around and post any findings here.
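In the meantime, here's the kind of sanity check I'm planning to run on the Python side first, to confirm the freeze only bites inside the gan model (toy models again, not the repo's actual code):

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# toy stand-ins, just to observe which training step moves the critic's weights
critic = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(8, activation="relu"), layers.Dense(1)])
generator = keras.Sequential([keras.Input(shape=(8,)), layers.Dense(4)])

critic.compile(loss="mse", optimizer="rmsprop")  # compiled while still trainable
critic.trainable = False                         # should only affect models compiled after this point
gan_model = keras.Sequential([generator, critic])
gan_model.compile(loss="mse", optimizer="rmsprop")

def critic_weights():
    return [w.copy() for w in critic.get_weights()]

def changed(before, after):
    return not all(np.allclose(a, b) for a, b in zip(before, after))

before = critic_weights()
gan_model.train_on_batch(np.random.randn(16, 8), np.ones((16, 1)))
print("critic moved by gan step:", changed(before, critic_weights()))      # expecting False

before = critic_weights()
critic.train_on_batch(np.random.randn(16, 4), np.ones((16, 1)))
print("critic moved by its own step:", changed(before, critic_weights()))  # expecting True

If that holds, the TF.js port should just need to replicate the same ordering: set up the critic's optimizer while it's trainable, then freeze it only for the combined model.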
