_model_py module

class _model.Decoder(r, r_prime, k, final, alpha, dropout, model_type, correction)

Bases: Module

Decoder neural network module for a GAT-based autoencoder.

Parameters:
  • r (int) -- Dimensionality of the input latent space.

  • r_prime (int) -- Dimensionality of the intermediate space.

  • k (List[int]) -- List of the number of categories for each categorical variable. None values are ignored.

  • final (int) -- Number of output features.

  • alpha (float) -- Negative slope for the LeakyReLU activation function.

  • dropout (float) -- Dropout probability for regularization.

  • model_type (str) -- Type of the model ('single', 'full', or 'simple').

  • correction (bool) -- Flag indicating whether to apply correction using categorical variables.

Variables:
  • layer1 (nn.Linear) -- First linear transformation layer.

  • layer2 (nn.Linear) -- Second linear transformation layer.

  • layer3 (nn.Linear) -- Third linear transformation layer.

  • leaky_relu (nn.LeakyReLU) -- LeakyReLU activation function.

  • dropout1 (nn.Dropout) -- Dropout layer after the first linear transformation.

  • dropout2 (nn.Dropout) -- Dropout layer after the second linear transformation.

  • dropout (nn.Dropout) -- Dropout layer for the 'simple' model type.

decode(Z, ejs, K, k, correction)

Decodes the latent representation using linear transformations and activations.

Parameters:
  • Z (torch.Tensor) -- Encoded latent space representation tensor.

  • ejs (Any) -- Placeholder parameter (not used in the method).

  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • correction (bool) -- Flag indicating whether to apply correction using the correction variables.

Returns:

Decoded output tensor in the original feature space.

Return type:

torch.Tensor
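The layer stack described for this class can be pictured with the minimal sketch below. It assumes that layer1 maps r to r_prime, layer2 keeps r_prime, and layer3 maps to final, with LeakyReLU and dropout after the first two layers. The class name `SketchDecoder` is hypothetical, and the correction variables and the 'single'/'full'/'simple' variants are omitted.

```python
import torch
import torch.nn as nn


class SketchDecoder(nn.Module):
    """Hypothetical three-layer decoder; layer sizes are assumptions."""

    def __init__(self, r: int, r_prime: int, final: int,
                 alpha: float = 0.2, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(r, r_prime)        # latent -> intermediate (assumed)
        self.layer2 = nn.Linear(r_prime, r_prime)  # intermediate -> intermediate (assumed)
        self.layer3 = nn.Linear(r_prime, final)    # intermediate -> output features
        self.leaky_relu = nn.LeakyReLU(negative_slope=alpha)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def decode(self, Z: torch.Tensor) -> torch.Tensor:
        h = self.dropout1(self.leaky_relu(self.layer1(Z)))
        h = self.dropout2(self.leaky_relu(self.layer2(h)))
        return self.layer3(h)  # no activation on the output layer (assumption)


decoder = SketchDecoder(r=8, r_prime=32, final=100)
Z = torch.randn(5, 8)  # 5 samples, 8 latent dimensions
X_hat = decoder.decode(Z)
print(X_hat.shape)  # torch.Size([5, 100])
```

Calling `decoder.eval()` disables both dropout layers, so repeated decodes of the same Z become deterministic.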

init_weights(m)

Initialize the weights of Linear layers using Xavier initialization.

Parameters:

m (nn.Module) -- A PyTorch module instance.

Returns:

None
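A minimal sketch of this kind of initializer, assuming the Xavier-uniform variant and zero biases (the source does not say which Xavier variant is used or how biases are treated). `nn.Module.apply` visits every submodule recursively, so one call initializes every `nn.Linear` in the model.

```python
import torch.nn as nn


def init_weights(m: nn.Module) -> None:
    # Assumed variant: Xavier-uniform weights and zero biases for nn.Linear only
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


model = nn.Sequential(nn.Linear(10, 4), nn.LeakyReLU(0.2), nn.Linear(4, 2))
model.apply(init_weights)  # visits every submodule, including nested ones
```

Xavier-uniform draws each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)), which keeps activation variance roughly constant across layers.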

prepare_latent_space_with_korrection_vars(K, k, lat_space)

Incorporates the correction variables into a latent space tensor.

Parameters:
  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • lat_space (torch.Tensor) -- Latent space representation tensor.

Returns:

A tuple containing:

  • h: The modified latent space tensor with the correction variables incorporated.

  • total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]
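The return contract (a modified latent tensor plus the summed categorical embeddings, or None) can be illustrated with the hypothetical sketch below. It assumes that categorical variables are integer codes embedded to the latent width and added to `lat_space`, and that continuous variables (None in k) are appended as extra columns. The actual combination rule is not stated in the source, and `prepare_latent_space_sketch` is an illustrative name.

```python
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


def prepare_latent_space_sketch(
    K: List[torch.Tensor],
    k: List[Optional[int]],
    lat_space: torch.Tensor,
    embeddings: nn.ModuleList,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: add categorical embeddings, append continuous columns."""
    total_cat: Optional[torch.Tensor] = None
    continuous: List[torch.Tensor] = []
    cat_idx = 0  # position in `embeddings`, which holds categorical variables only
    for K_i, k_i in zip(K, k):
        if k_i is None:
            # Continuous variable (assumed): appended as one extra column
            continuous.append(K_i.float().view(-1, 1))
        else:
            # Categorical variable (assumed integer codes): embed to lat_space width
            emb = embeddings[cat_idx](K_i.long().view(-1))
            cat_idx += 1
            total_cat = emb if total_cat is None else total_cat + emb
    h = lat_space if total_cat is None else lat_space + total_cat
    if continuous:
        h = torch.cat([h] + continuous, dim=1)
    return h, total_cat


latent_dim = 4
k = [3, None]  # one categorical variable with 3 levels, one continuous variable
embeddings = nn.ModuleList([nn.Embedding(3, latent_dim)])
K = [torch.tensor([0, 2, 1]), torch.tensor([0.5, 1.0, 1.5])]
h, total_cat = prepare_latent_space_sketch(K, k, torch.zeros(3, latent_dim), embeddings)
print(h.shape)  # torch.Size([3, 5])
```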

training: bool
class _model.Encoder(in_channels, k, r_prime, r, h, alpha, dropout, model_type, correction)

Bases: Module

Encoder neural network module using Graph Attention Networks (GATs).

Variables:
  • layer1 (GATv2Conv) -- First graph attention layer.

  • layer2 (nn.Linear) -- Second linear transformation layer.

  • layer3 (nn.Linear) -- Third linear transformation layer.

  • mu_layer (nn.Linear) -- Linear layer to compute the mean of the latent space representation.

  • logvar_layer (nn.Linear) -- Linear layer to compute the log variance of the latent space representation.

  • leaky_relu (nn.LeakyReLU) -- LeakyReLU activation function.

encode(X, ejs, K, k, correction)

Perform encoding of the input features using graph attention layers.

Parameters:
  • X (torch.Tensor) -- Input feature matrix (zeroed counts).

  • ejs (torch.Tensor) -- Edge indices defining the graph structure.

  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • correction (bool) -- Flag indicating whether to apply correction using the correction variables.

Returns:

A tuple containing:

  • mu: Mean of the encoded input.

  • logvar: Log variance of the encoded input.

Return type:

Tuple[torch.Tensor, torch.Tensor]
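PyTorch Geometric layers such as GATv2Conv take the graph as an `edge_index` tensor of shape [2, num_edges] (row 0 holds source nodes, row 1 target nodes). Assuming `ejs` follows that convention, the sketch below shows one way to derive it from a dense sample-sample adjacency matrix; the toy adjacency is illustrative only. The resulting tensor would then be passed as `conv(X, ejs)` to a layer built as `GATv2Conv(in_channels, out_channels)`.

```python
import torch

# Toy dense, symmetric adjacency over 4 samples (a simple chain 0-1-2-3)
adj = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
])

# COO edge index of shape [2, num_edges]: row 0 = source, row 1 = target
ejs = adj.nonzero(as_tuple=False).t().contiguous()
print(ejs.shape)  # torch.Size([2, 6])
```

Each undirected edge appears twice (once per direction), which is how message passing in both directions is usually expressed.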

init_weights(m)

Initialize the weights of Linear layers.

Parameters:

m (nn.Module) -- A PyTorch module instance.

Returns:

None

prepare_latent_space_with_korrection_vars(K, k, lat_space)

Incorporates the correction variables into a latent space tensor.

Parameters:
  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • lat_space (torch.Tensor) -- Latent space representation tensor.

Returns:

A tuple containing:

  • h: The modified latent space tensor with the correction variables incorporated.

  • total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]

training: bool
class _model.LibDecoder(r, r_prime, k, final, alpha, correction)

Bases: Module

Library-size decoder module that decodes latent variables back into the original space.

Variables:
  • lib_decode_size_factor (nn.Linear) -- Linear layer for decoding the latent variables.

  • lib_decode_size_factor_2 (nn.Linear) -- Linear layer that further decodes the library size.

  • leaky_relu (nn.LeakyReLU) -- Leaky ReLU activation function.

  • embeddings (nn.ModuleList) -- Module list of embedding layers for categorical correction variables.

decode(Z_L, K, k, correction)

Perform the decoding operation for latent variables and correction variables.

Parameters:
  • Z_L (torch.Tensor) -- Latent variables tensor.

  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • correction (bool) -- Flag indicating whether to apply correction using the correction variables.

Returns:

Decoded output tensor representing the reconstructed library size.

Return type:

torch.Tensor
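The two decoding layers listed for this class can be sketched as follows. This is a hypothetical reading, assuming the first layer maps r to r_prime, LeakyReLU follows it, and the second layer produces one library-size value per sample; correction variables are omitted, and `SketchLibDecoder` is an illustrative name.

```python
import torch
import torch.nn as nn


class SketchLibDecoder(nn.Module):
    """Hypothetical two-layer library-size decoder; sizes are assumptions."""

    def __init__(self, r: int, r_prime: int, alpha: float = 0.2):
        super().__init__()
        self.lib_decode_size_factor = nn.Linear(r, r_prime)
        self.lib_decode_size_factor_2 = nn.Linear(r_prime, 1)  # one value per sample
        self.leaky_relu = nn.LeakyReLU(negative_slope=alpha)

    def decode(self, Z_L: torch.Tensor) -> torch.Tensor:
        h = self.leaky_relu(self.lib_decode_size_factor(Z_L))
        return self.lib_decode_size_factor_2(h)  # shape [num_samples, 1]


lib_decoder = SketchLibDecoder(r=2, r_prime=16)
DL = lib_decoder.decode(torch.randn(6, 2))
print(DL.shape)  # torch.Size([6, 1])
```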

init_weights(m)

Initialize the weights of Linear layers using Xavier initialization.

Parameters:

m (nn.Module) -- A PyTorch module instance.

Returns:

None

prepare_latent_space_with_korrection_vars(K, k, lat_space)

Incorporates the correction variables into a latent space tensor.

Parameters:
  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • lat_space (torch.Tensor) -- Latent space representation tensor.

Returns:

A tuple containing:

  • h: The modified latent space tensor with the correction variables incorporated.

  • total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]

training: bool
class _model.LibEncoder(k, r_prime, r, alpha, dropout, correction)

Bases: Module

Library-size encoder module using a linear transformation layer followed by a LeakyReLU activation.

Variables:
  • layer_lib1 (nn.Linear) -- Linear layer that encodes the log library sizes combined with the correction variables.

  • layer_lib_mu (nn.Linear) -- Linear layer to encode the mean for the library size.

  • layer_lib_logvar (nn.Linear) -- Linear layer to encode the log variance (logvar) for the library size.

  • leaky_relu (nn.LeakyReLU) -- Leaky ReLU activation function.

  • dropout1 (nn.Dropout) -- Dropout layer to prevent overfitting.

  • embeddings (nn.ModuleList) -- Module list of embedding layers for categorical correction variables.

encode(log_lib, K, k, correction)

Perform the encoding operation for log library sizes and correction variables.

Parameters:
  • log_lib (torch.Tensor) -- Tensor of log library sizes.

  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • correction (bool) -- Flag indicating whether to apply correction using the correction variables.

Returns:

A tuple containing:

  • mu: Tensor representing the mean of the latent space representation.

  • logvar: Tensor representing the log variance of the latent space representation.

Return type:

Tuple[torch.Tensor, torch.Tensor]
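A hypothetical sketch of the library-size encoding path is shown below. It assumes that the log library size is the log of each sample's total count, that covariates enter as one-hot columns concatenated to it (the documented class uses embeddings for categorical variables instead), and that mu and logvar come from two linear heads after LeakyReLU and dropout.

```python
import torch
import torch.nn as nn


class SketchLibEncoder(nn.Module):
    """Hypothetical library-size encoder; layer sizes are assumptions."""

    def __init__(self, n_cov: int, r_prime: int, r: int,
                 alpha: float = 0.2, dropout: float = 0.1):
        super().__init__()
        # Input: one log-library-size column plus n_cov covariate columns (assumed)
        self.layer_lib1 = nn.Linear(1 + n_cov, r_prime)
        self.layer_lib_mu = nn.Linear(r_prime, r)
        self.layer_lib_logvar = nn.Linear(r_prime, r)
        self.leaky_relu = nn.LeakyReLU(negative_slope=alpha)
        self.dropout1 = nn.Dropout(dropout)

    def encode(self, log_lib: torch.Tensor, covariates: torch.Tensor):
        h = torch.cat([log_lib, covariates], dim=1)
        h = self.dropout1(self.leaky_relu(self.layer_lib1(h)))
        return self.layer_lib_mu(h), self.layer_lib_logvar(h)


counts = torch.tensor([[3.0, 0.0, 7.0], [1.0, 1.0, 0.0]])  # 2 samples x 3 genes
log_lib = torch.log(counts.sum(dim=1, keepdim=True))      # [[log 10], [log 2]]
covariates = torch.tensor([[1.0, 0.0], [0.0, 1.0]])       # one-hot batch labels
encoder = SketchLibEncoder(n_cov=2, r_prime=8, r=1)
mu, logvar = encoder.encode(log_lib, covariates)
```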

init_weights(m)

Initialize the weights of Linear layers using Xavier initialization.

Parameters:

m (nn.Module) -- A PyTorch module instance.

Returns:

None

prepare_latent_space_with_korrection_vars(K, k, lat_space)

Incorporates the correction variables into a latent space tensor.

Parameters:
  • K (List[torch.Tensor]) -- List of correction variable tensors.

  • k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.

  • lat_space (torch.Tensor) -- Latent space representation tensor.

Returns:

A tuple containing:

  • h: The modified latent space tensor with the correction variables incorporated.

  • total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]

training: bool
class _model.VAE(N, M, ks, log, configs)

Bases: Module

Variational Autoencoder (VAE) that captures the latent structure of the data. The constructor prepares the encoders, decoders, and linear layers of the VAE.

Parameters:
  • N (int) -- Number of genes.

  • M (int) -- Number of samples.

  • ks (List[int]) -- List of integers representing the number of categories for each correction variable.

  • log (logging.Logger) -- Logger for outputting information during model operations.

  • configs (dict) -- Configuration dictionary containing parameters for the VAE.

batch_centroid_loss(counts, Ks)

Computes a loss based on the Euclidean distance between centroids of each experimental batch in the latent space.

Parameters:
  • counts (List[torch.Tensor]) -- List of tensors containing the latent representations of the samples. Each element corresponds to a different tissue or condition and has shape [num_samples, latent_dim].

  • Ks (List[torch.Tensor]) -- List of tensors where the first column indicates batch membership for each sample in counts. Each tensor corresponds to a different tissue or condition and has shape [num_samples, num_batches].

Returns:

A tensor with one entry per tissue or condition. Each entry is the mean of the non-zero, upper-triangular pairwise Euclidean distances between the batch centroids of that tissue or condition.

Return type:

torch.Tensor
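The computation described above can be sketched as follows: per tissue, average the latent vectors of each batch into a centroid, compute all pairwise centroid distances with `torch.cdist`, keep each pair once via the strict upper triangle, and average the non-zero values. Returning 0 for a tissue with a single batch is an assumption, and `batch_centroid_loss_sketch` is an illustrative name.

```python
from typing import List

import torch


def batch_centroid_loss_sketch(counts: List[torch.Tensor],
                               Ks: List[torch.Tensor]) -> torch.Tensor:
    per_tissue = []
    for Z, K in zip(counts, Ks):
        batch = K[:, 0]                                  # batch label of each sample
        labels = torch.unique(batch)
        centroids = torch.stack([Z[batch == b].mean(dim=0) for b in labels])
        dist = torch.cdist(centroids, centroids)         # [n_batches, n_batches]
        upper = torch.triu(dist, diagonal=1)             # each pair counted once
        nonzero = upper[upper > 0]
        # Assumption: a tissue with a single batch contributes 0
        per_tissue.append(nonzero.mean() if nonzero.numel() > 0 else Z.new_zeros(()))
    return torch.stack(per_tissue)


# Tissue 1: two batches with centroids (0, 0) and (3, 4); tissue 2: one batch
Z1 = torch.tensor([[0.0, 0.0], [0.0, 0.0], [3.0, 4.0], [3.0, 4.0]])
K1 = torch.tensor([[0.0], [0.0], [1.0], [1.0]])
Z2, K2 = torch.randn(3, 2), torch.zeros(3, 1)
loss = batch_centroid_loss_sketch([Z1, Z2], [K1, K2])
print(loss)  # tensor([5., 0.])
```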

decode(latent_vars, Ks, ks, edges)

Decodes the latent variables and combines them with correction variables to calculate ZINB parameters for each type of decoder configuration.

Parameters:
  • latent_vars (Dict[str, List[torch.Tensor]]) -- Dictionary containing reparameterized latent spaces computed by the encode_reparameterization() method. It includes keys like 'Z_Ls' for library sizes and 'Z_As' for sample-sample interactions.

  • Ks (List[torch.Tensor]) -- List of one-hot encoded matrices corresponding to batch or other categorical variables, one for each mini-batch.

  • ks (List[int]) -- List of integers representing the number of categories for each correction variable.

  • edges (List[torch.Tensor]) -- List of adjacency matrices or edge lists representing sample-sample interactions, one for each mini-batch.

Returns:

A dictionary containing:

  • DLs: List of decoded library sizes.

  • DAs: List of decoded sample-sample interactions.

  • distributional_parameters: Dictionary containing the ZINB distributional parameters (pi, omega, and theta).

Return type:

Dict[str, Union[List[torch.Tensor], Dict[str, torch.Tensor]]]
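How raw decoder outputs might become valid ZINB parameters is sketched below. The activations are assumptions following a convention common in single-cell count VAEs, not confirmed by the source: a sigmoid keeps pi in (0, 1), a softmax over genes times the library size gives the mean omega, and an exponential keeps theta positive. `zinb_parameters_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def zinb_parameters_sketch(h_pi, h_omega, h_theta, lib_size):
    """Hypothetical mapping from raw decoder outputs to ZINB parameters."""
    pi = torch.sigmoid(h_pi)                       # zero-inflation probability in (0, 1)
    omega = F.softmax(h_omega, dim=-1) * lib_size  # per-gene mean: proportion x library size
    theta = torch.exp(h_theta)                     # dispersion, kept strictly positive
    return {"pi": pi, "omega": omega, "theta": theta}


h = torch.randn(3, 5)                                # 3 samples x 5 genes
lib_size = torch.tensor([[100.0], [200.0], [50.0]])  # one library size per sample
params = zinb_parameters_sketch(h, h, h, lib_size)
```

With this choice, each sample's predicted means sum to its library size, which separates "how much was sequenced" from "which genes are expressed".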

encode_reparameterization(Xs, Ks, ks, edges)

Encodes and reparameterizes the input data to produce latent structures for the library size (Z_L) and network structure of sample-sample interactions (Z_A).

Parameters:
  • Xs (List[torch.Tensor]) -- List of zeroed expression matrices, one for each mini-batch.

  • Ks (List[torch.Tensor]) -- List of one-hot encoded matrices corresponding to batch or other categorical variables, one for each mini-batch.

  • ks (List[int]) -- List of integers representing the number of categories for each correction variable.

  • edges (List[torch.Tensor]) -- List of adjacency matrices or edge lists representing sample-sample interactions, one for each mini-batch.

Returns:

A dictionary containing two sub-dictionaries:

  • latent_spaces: Contains the mean (mu) and log variance (logvar) for the library size, network interactions, expression matrix, and simple models for each mini-batch.

  • latent_variables: Contains the reparameterized latent variables Z_L, Z_A, Z_X, and Z_simple for each mini-batch.

Return type:

Dict[str, Dict[str, List[torch.Tensor]]]

forward(batch)

Processes a batch of data through the VAE, performing encoding, reparameterization, and decoding steps to generate the outputs used for model training or inference.

Parameters:

batch (Dict[str, List[torch.Tensor]]) --

A dictionary containing tensors that represent different parts of the data batch. Expected keys are:

  • 'X_batches': Zeroed expression matrices of the minibatch.

  • 'R_batches': Raw expression matrices of the minibatch.

  • 'K_batches': Correction variables.

  • 'k_batches': Levels of correction variables.

  • 'idx_batches': Indices of samples in the minibatch.

  • 'ej_batches': Graph edges in each minibatch (used if the model includes graph data).

Returns:

A dictionary containing various outputs from the forward pass of the model, including:

  • 'latent_spaces': The latent spaces derived from the encoder.

  • 'latent_variables': Reparameterized latent variables.

  • 'X_hat': Predicted data samples (e.g., reconstructed expression levels).

  • 'DAs': Decoded sample-sample interactions.

  • 'DLs': Decoded library size factors.

  • 'lib_size_factors': Library size factors computed post-decoding.

  • 'px_dispersion': Dispersion parameters of the distribution.

  • 'px_omega': Mean (mu) parameters of the distribution.

  • 'distributional_parameters': ZINB distributional parameters (pi, omega, and theta).

Return type:

Dict[str, Union[torch.distributions.Distribution, List[torch.Tensor], Dict[str, torch.Tensor]]]

init_weights_vae(m)

Initialize the weights of Linear layers using Xavier initialization.

Parameters:

m (nn.Module) -- A PyTorch module instance.

Returns:

None

load_pretrained_model()

Loads a pre-trained model state into this model instance. It loads the model's state dictionary, updates the current model instance's parameters, and sets the model to evaluation mode.

Raises:

Exception -- If there are any issues accessing the folder or loading the model file.
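The load-then-evaluate sequence described above follows the standard PyTorch pattern sketched here. The explicit `path` argument is hypothetical (the documented method takes no arguments and presumably reads its location from the configuration), and wrapping I/O errors in `Exception` mirrors the documented "Raises" entry.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_pretrained_sketch(model: nn.Module, path: str) -> nn.Module:
    try:
        state_dict = torch.load(path, map_location="cpu")
    except (OSError, RuntimeError) as err:
        raise Exception(f"Could not load model from {path!r}") from err
    model.load_state_dict(state_dict)  # copy saved parameters into this instance
    model.eval()                       # evaluation mode: dropout disabled
    return model


# Round trip through a throwaway file
source = nn.Linear(3, 2)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(source.state_dict(), path)
target = load_pretrained_sketch(nn.Linear(3, 2), path)
```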

loss(Rs_batch, idx_batch, adj_batch, dataset, prefix, epoch, generative_outputs, losses, log)

Calculates and records various losses during training or validation.

Parameters:
  • Rs_batch (List[torch.Tensor]) -- Raw expression matrices for the current minibatch, where each tensor corresponds to a batch from a specific condition or tissue.

  • idx_batch (List[int]) -- List of indices corresponding to samples in the current minibatch.

  • adj_batch (List[torch.Tensor]) -- Adjacency matrices for samples in the minibatch, applicable for models considering sample-sample interactions.

  • dataset (FFPE_Dataset) -- Dataset object providing access to dataset properties and helper methods.

  • prefix (str) -- Indicates the phase of the model ('train' or 'val') during which the loss is being computed.

  • epoch (int) -- The current epoch number in the training/validation process.

  • generative_outputs (Dict[str, Any]) -- Outputs from the forward pass of the VAE model including latent variables and other intermediate data.

  • losses (Dict[str, float]) -- Dictionary to record and update the computed losses over training epochs.

  • log (Logger) -- Logger object for logging the computed losses.

Returns:

The average loss computed across different metrics for the current minibatch.

Return type:

torch.Tensor
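Two terms that a loss of this kind typically combines are sketched below: the zero-inflated negative binomial (ZINB) log-likelihood of the raw counts and the Gaussian KL divergence of a latent space. Which terms the documented method actually uses, and how it weights them, are not stated; the plain average at the end is an assumption, and here `mu` in `zinb_log_prob` plays the role of omega (the NB mean).

```python
import torch


def zinb_log_prob(x, mu, theta, pi, eps=1e-8):
    """Element-wise log-likelihood of counts x under ZINB(mu, theta, pi)."""
    log_theta_mu = torch.log(theta + mu + eps)
    log_nb_zero = theta * (torch.log(theta + eps) - log_theta_mu)  # log NB(0 | mu, theta)
    log_nb = (torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1)
              + log_nb_zero + x * (torch.log(mu + eps) - log_theta_mu))
    zero_case = torch.log(pi + (1 - pi) * torch.exp(log_nb_zero) + eps)
    nonzero_case = torch.log(1 - pi + eps) + log_nb
    return torch.where(x == 0, zero_case, nonzero_case)


def gaussian_kl(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)


x = torch.tensor([[0.0, 1.0]])  # one sample, two genes
ones = torch.ones_like(x)
nll = -zinb_log_prob(x, mu=ones, theta=ones, pi=0.5 * ones).sum(dim=-1).mean()
kl = gaussian_kl(torch.zeros(1, 2), torch.zeros(1, 2)).mean()
total = torch.stack([nll, kl]).mean()  # assumed: plain average of the terms
```

For mu = theta = 1 and pi = 0.5, a zero count has probability 0.5 + 0.5 · 0.5 = 0.75 and a count of one has probability 0.5 · 0.25 = 0.125.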

remove_ghost_samples(adj_Rs_batch, idx_batch, dataset, adjusted_generative_outputs)

Removes ghost samples from tensors, distributions, and lists of vectors and matrices.

Parameters:
  • adj_Rs_batch (torch.Tensor) -- Expression tensor of the current minibatch of samples.

  • idx_batch (List[int]) -- List of indices of samples in the minibatch.

  • dataset (Dataset) -- Dataset object created in _data_loader.py.

  • adjusted_generative_outputs (Dict[str, Union[torch.distributions.Distribution, List[torch.Tensor], Dict[str, torch.Tensor]]]) -- Output dictionary from the VAE containing tensors, distributions, and lists of vectors and matrices.

Returns:

None
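The source does not define ghost samples or how they are located. Assuming they are padding rows identified by a boolean mask, the in-place filtering (consistent with the None return value) could look like the hypothetical helper below. Distributions are left untouched here because they would need to be rebuilt from filtered parameters.

```python
from typing import Dict, List, Union

import torch

Output = Union[torch.Tensor, List[torch.Tensor]]


def drop_ghost_rows(outputs: Dict[str, Output], keep: torch.Tensor) -> None:
    """Hypothetical helper: keep only real-sample rows, modifying `outputs` in place."""
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            outputs[key] = value[keep]
        elif isinstance(value, list):
            outputs[key] = [v[keep] for v in value]
        # Other values (e.g. distributions) would need to be rebuilt from their
        # filtered parameters; they are left untouched in this sketch.


keep = torch.tensor([True, True, False])  # assumed: the third row is a ghost sample
outputs = {"X_hat": torch.ones(3, 4), "DLs": [torch.zeros(3, 1)]}
drop_ghost_rows(outputs, keep)
print(outputs["X_hat"].shape)  # torch.Size([2, 4])
```

A sample-by-sample matrix would need filtering on both axes, for example `m[keep][:, keep]`.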

reparameterize(mu, logvar)

Reparameterization method to sample from a Gaussian distribution.

Parameters:
  • mu (torch.Tensor) -- Mean of the Gaussian distribution.

  • logvar (torch.Tensor) -- Natural log of the variance of the Gaussian distribution.

Returns:

Sampled latent variable z.

Return type:

torch.Tensor
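The standard reparameterization trick writes the sample as z = mu + sigma · eps with eps ~ N(0, I) and sigma = exp(logvar / 2), so gradients flow to mu and logvar while the randomness sits in eps. A minimal standalone sketch (the function name is illustrative):

```python
import torch


def reparameterize_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I); carries no gradient
    return mu + eps * std          # z = mu + sigma * eps, differentiable in mu and logvar


mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
z = reparameterize_sketch(mu, logvar)
z.sum().backward()  # gradients reach mu and logvar through the sample
```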

set_parameter_requires_grad()

Sets the requires_grad flag on the parameters of specific layers to enable or disable their training.

Usage:

This method is typically called after model initialization or loading a pre-trained model to prepare the model for fine-tuning or full training, depending on the experiment's requirements.
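Which layers the method freezes is configuration-dependent and not listed in the source. The general mechanism can be sketched with a hypothetical helper that freezes every parameter whose qualified name starts with one of the given prefixes and enables all others.

```python
from typing import Iterable

import torch.nn as nn


def freeze_by_prefix(model: nn.Module, frozen_prefixes: Iterable[str]) -> None:
    """Hypothetical helper: freeze parameters whose names start with a given prefix."""
    prefixes = tuple(frozen_prefixes)
    for name, param in model.named_parameters():
        # Frozen parameters receive no gradient; all others are (re)enabled
        param.requires_grad = not name.startswith(prefixes)


model = nn.ModuleDict({"encoder": nn.Linear(4, 2), "decoder": nn.Linear(2, 4)})
freeze_by_prefix(model, ["encoder."])
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['decoder.weight', 'decoder.bias']
```

When fine-tuning, passing only the trainable parameters to the optimizer (for example `torch.optim.Adam(p for p in model.parameters() if p.requires_grad)`) avoids keeping optimizer state for frozen weights.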

training: bool