From 42f7a9273b5f7ed7a91ee758fe38aef3f0f343e7 Mon Sep 17 00:00:00 2001 From: Mohsen Naghipourfar Date: Wed, 21 Aug 2019 14:32:39 +0200 Subject: [PATCH] Update scGen keras network --- scgen/models/_vae_keras.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/scgen/models/_vae_keras.py b/scgen/models/_vae_keras.py index 1553859..4a7a0ec 100644 --- a/scgen/models/_vae_keras.py +++ b/scgen/models/_vae_keras.py @@ -98,7 +98,7 @@ def _encoder(self): log_var = Dense(self.z_dim, kernel_initializer=self.init_w)(h) z = Lambda(self._sample_z, output_shape=(self.z_dim,), name="Z")([mean, log_var]) - self.encoder_model = Model(inputs=self.x, outputs=[mean, log_var, z], name="encoder") + self.encoder_model = Model(inputs=self.x, outputs=z, name="encoder") return mean, log_var def _decoder(self): @@ -178,7 +178,7 @@ def _create_network(self): self.mu, self.log_var = self._encoder() self.x_hat = self._decoder() - self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)[2]), name="VAE") + self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)), name="VAE") def _loss_function(self): """ @@ -224,7 +224,7 @@ def to_latent(self, data): latent: numpy nd-array Returns array containing latent space encoding of 'data' """ - latent = self.encoder_model.predict(data)[2] + latent = self.encoder_model.predict(data) return latent def _avg_vector(self, data): @@ -246,7 +246,7 @@ def _avg_vector(self, data): latent_avg = numpy.average(latent, axis=0) return latent_avg - def reconstruct(self, data, use_data=False): + def reconstruct(self, data): """ Map back the latent space encoding via the decoder. @@ -265,11 +265,6 @@ def reconstruct(self, data, use_data=False): rec_data: 'numpy nd-array' Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars]. """ - # if use_data: - # latent = data - # else: - # latent = self.to_latent(data) - # rec_data = self.sess.run(self.x_hat, feed_dict={self.z_mean: latent, self.is_training: False}) rec_data = self.decoder_model.predict(x=data) return rec_data @@ -321,7 +316,7 @@ def linear_interpolation(self, source_adata, dest_adata, n_steps): vector = start * (1 - alpha) + end * alpha vectors[i, :] = vector vectors = numpy.array(vectors) - interpolation = self.reconstruct(vectors, use_data=True) + interpolation = self.reconstruct(vectors) return interpolation def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_predict=None, celltype_to_predict=None, obs_key="all"): @@ -393,7 +388,7 @@ def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_pred else: latent_cd = self.to_latent(ctrl_pred.X) stim_pred = delta + latent_cd - predicted_cells = self.reconstruct(stim_pred, use_data=True) + predicted_cells = self.reconstruct(stim_pred) return predicted_cells, delta def restore_model(self):