Skip to content

Commit

Permalink
Update scGen keras network
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohsen Naghipourfar committed Aug 21, 2019
1 parent 22dcc45 commit 42f7a92
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions scgen/models/_vae_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _encoder(self):
log_var = Dense(self.z_dim, kernel_initializer=self.init_w)(h)
z = Lambda(self._sample_z, output_shape=(self.z_dim,), name="Z")([mean, log_var])

self.encoder_model = Model(inputs=self.x, outputs=[mean, log_var, z], name="encoder")
self.encoder_model = Model(inputs=self.x, outputs=z, name="encoder")
return mean, log_var

def _decoder(self):
Expand Down Expand Up @@ -178,7 +178,7 @@ def _create_network(self):
self.mu, self.log_var = self._encoder()

self.x_hat = self._decoder()
self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)[2]), name="VAE")
self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)), name="VAE")

def _loss_function(self):
"""
Expand Down Expand Up @@ -224,7 +224,7 @@ def to_latent(self, data):
latent: numpy nd-array
Returns array containing latent space encoding of 'data'
"""
latent = self.encoder_model.predict(data)[2]
latent = self.encoder_model.predict(data)
return latent

def _avg_vector(self, data):
Expand All @@ -246,7 +246,7 @@ def _avg_vector(self, data):
latent_avg = numpy.average(latent, axis=0)
return latent_avg

def reconstruct(self, data, use_data=False):
def reconstruct(self, data):
"""
Map back the latent space encoding via the decoder.
Expand All @@ -265,11 +265,6 @@ def reconstruct(self, data, use_data=False):
rec_data: 'numpy nd-array'
Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars].
"""
# if use_data:
# latent = data
# else:
# latent = self.to_latent(data)
# rec_data = self.sess.run(self.x_hat, feed_dict={self.z_mean: latent, self.is_training: False})
rec_data = self.decoder_model.predict(x=data)
return rec_data

Expand Down Expand Up @@ -321,7 +316,7 @@ def linear_interpolation(self, source_adata, dest_adata, n_steps):
vector = start * (1 - alpha) + end * alpha
vectors[i, :] = vector
vectors = numpy.array(vectors)
interpolation = self.reconstruct(vectors, use_data=True)
interpolation = self.reconstruct(vectors)
return interpolation

def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_predict=None, celltype_to_predict=None, obs_key="all"):
Expand Down Expand Up @@ -393,7 +388,7 @@ def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_pred
else:
latent_cd = self.to_latent(ctrl_pred.X)
stim_pred = delta + latent_cd
predicted_cells = self.reconstruct(stim_pred, use_data=True)
predicted_cells = self.reconstruct(stim_pred)
return predicted_cells, delta

def restore_model(self):
Expand Down

0 comments on commit 42f7a92

Please sign in to comment.