_model.py module
- class _model.Decoder(r, r_prime, k, final, alpha, dropout, model_type, correction)
Bases:
Module
Decoder neural network module for a GAT-based autoencoder
- Parameters:
r (int) -- Dimensionality of the input latent space.
r_prime (int) -- Dimensionality of the intermediate space.
k (List[int]) -- List of the number of categories for each categorical variable. None values are ignored.
final (int) -- Number of output features.
alpha (float) -- Negative slope for the LeakyReLU activation function.
dropout (float) -- Dropout probability for regularization.
model_type (str) -- Type of the model ('single', 'full', or 'simple').
correction (bool) -- Flag indicating whether to apply correction using categorical variables.
- Variables:
layer1 (nn.Linear) -- First linear transformation layer.
layer2 (nn.Linear) -- Second linear transformation layer.
layer3 (nn.Linear) -- Third linear transformation layer.
leaky_relu (nn.LeakyReLU) -- LeakyReLU activation function.
dropout1 (nn.Dropout) -- Dropout layer after the first linear transformation.
dropout2 (nn.Dropout) -- Dropout layer after the second linear transformation.
dropout (nn.Dropout) -- Dropout layer for the 'simple' model type.
- decode(Z, ejs, K, k, correction)
Decodes using linear transformations and activations
- Parameters:
Z (torch.Tensor) -- Encoded latent space representation tensor.
ejs (Any) -- Placeholder parameter (not used in the method).
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
correction (bool) -- Flag indicating whether to apply correction using the correction variables.
- Returns:
Decoded output tensor in the original feature space.
- Return type:
torch.Tensor
- init_weights(m)
Initialize weights of Linear layers using Xavier initialization
- Parameters:
m (nn.Module) -- A PyTorch module instance.
- Returns:
None
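As a rough illustration of Xavier initialization applied to Linear layers, the sketch below uses a standalone function with `Module.apply`. The choice of the uniform variant and the zeroed bias are assumptions; the source only states that Xavier initialization is used.

```python
import torch
import torch.nn as nn


def init_weights(m: nn.Module) -> None:
    """Xavier-initialize Linear layers; other module types are left untouched."""
    if isinstance(m, nn.Linear):
        # Assumption: uniform variant of Xavier initialization.
        nn.init.xavier_uniform_(m.weight)
        # Assumption: biases are reset to zero.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# Hypothetical toy network, only used to demonstrate the call pattern.
net = nn.Sequential(nn.Linear(8, 4), nn.LeakyReLU(0.2), nn.Linear(4, 2))
net.apply(init_weights)  # apply() visits every submodule recursively
```

`Module.apply` calls the function on each submodule, so the `isinstance` check is what restricts the initialization to Linear layers.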
- prepare_latent_space_with_korrection_vars(K, k, lat_space)
- Parameters:
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
lat_space (torch.Tensor) -- Latent space representation tensor.
- Returns:
A tuple containing:
h: The modified latent space tensor with correction variables incorporated.
total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
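One way the documented return values could arise is sketched below. This is a guess at the mechanics, not the module's actual code: it assumes categorical variables (integer entries in `k`) are passed through `nn.Embedding` layers and summed into `total_cat`, while continuous variables (`None` entries in `k`) are appended to the latent tensor as extra columns. The function name `prepare_latent_space` and the explicit `embeddings` argument are hypothetical.

```python
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


def prepare_latent_space(
    K: List[torch.Tensor],
    k: List[Optional[int]],
    lat_space: torch.Tensor,
    embeddings: nn.ModuleList,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    h = lat_space
    total_cat = None
    emb_layers = iter(embeddings)  # one embedding per categorical variable
    for var, n_cat in zip(K, k):
        if n_cat is None:
            # Assumption: continuous variables become one extra column of h.
            h = torch.cat([h, var.float().reshape(-1, 1)], dim=1)
        else:
            # Assumption: categorical variables are embedded and summed.
            emb = next(emb_layers)(var.long())
            total_cat = emb if total_cat is None else total_cat + emb
    return h, total_cat


torch.manual_seed(0)
lat = torch.randn(5, 3)                               # 5 samples, 3 latent dims
K = [torch.tensor([0, 1, 2, 1, 0]), torch.rand(5)]    # one categorical, one continuous
k = [3, None]                                         # 3 categories, then continuous
embeddings = nn.ModuleList([nn.Embedding(3, 3)])
h, total_cat = prepare_latent_space(K, k, lat, embeddings)
```

With these inputs `h` gains one column for the continuous variable, and `total_cat` holds the embedding of the single categorical variable.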
- training: bool
- class _model.Encoder(in_channels, k, r_prime, r, h, alpha, dropout, model_type, correction)
Bases:
Module
Encoder neural network module using Graph Attention Networks (GATs)
- Variables:
layer1 (GATv2Conv) -- First graph attention layer.
layer2 (nn.Linear) -- Second linear transformation layer.
layer3 (nn.Linear) -- Third linear transformation layer.
mu_layer (nn.Linear) -- Linear layer to compute the mean of the latent space representation.
logvar_layer (nn.Linear) -- Linear layer to compute the log variance of the latent space representation.
leaky_relu (nn.LeakyReLU) -- LeakyReLU activation function.
- encode(X, ejs, K, k, correction)
Perform encoding using graph attention layers
- Parameters:
X (torch.Tensor) -- Input feature matrix (zeroed counts).
ejs (torch.Tensor) -- Edge indices defining the graph structure.
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
correction (bool) -- Flag indicating whether to apply correction using the correction variables.
- Returns:
A tuple containing:
mu: Mean of the encoded input.
logvar: Log variance of the encoded input.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- init_weights(m)
Initialize weights of Linear layers
- Parameters:
m (nn.Module) -- A PyTorch module instance.
- Returns:
None
- prepare_latent_space_with_korrection_vars(K, k, lat_space)
- Parameters:
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
lat_space (torch.Tensor) -- Latent space representation tensor.
- Returns:
A tuple containing:
h: The modified latent space tensor with correction variables incorporated.
total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- training: bool
- class _model.LibDecoder(r, r_prime, k, final, alpha, correction)
Bases:
Module
Library-size decoder module that decodes latent variables back into the original space
- Variables:
lib_decode_size_factor (nn.Linear) -- Linear layer for decoding the latent variables.
lib_decode_size_factor_2 (nn.Linear) -- Linear layer for further decoding the library size.
leaky_relu (nn.LeakyReLU) -- Leaky ReLU activation function.
embeddings (nn.ModuleList) -- Module list of embedding layers for categorical correction variables.
- decode(Z_L, K, k, correction)
Perform the decoding operation for latent variables and correction variables.
- Parameters:
Z_L (torch.Tensor) -- Latent variables tensor.
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
correction (bool) -- Flag indicating whether to apply correction using the correction variables.
- Returns:
Decoded output tensor representing the reconstructed library size.
- Return type:
torch.Tensor
- init_weights(m)
Initialize weights of Linear layers using Xavier initialization
- Parameters:
m (nn.Module) -- A PyTorch module instance.
- Returns:
None
- prepare_latent_space_with_korrection_vars(K, k, lat_space)
- Parameters:
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
lat_space (torch.Tensor) -- Latent space representation tensor.
- Returns:
A tuple containing:
h: The modified latent space tensor with correction variables incorporated.
total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- training: bool
- class _model.LibEncoder(k, r_prime, r, alpha, dropout, correction)
Bases:
Module
Library-size encoder module using a linear transformation layer followed by LeakyReLU activation
- Variables:
layer_lib1 (nn.Linear) -- Linear layer to encode the combined features of log library sizes and correction variables.
layer_lib_mu (nn.Linear) -- Linear layer to encode the mean for the library size.
layer_lib_logvar (nn.Linear) -- Linear layer to encode the log variance (logvar) for the library size.
leaky_relu (nn.LeakyReLU) -- Leaky ReLU activation function.
dropout1 (nn.Dropout) -- Dropout layer to prevent overfitting.
embeddings (nn.ModuleList) -- Module list of embedding layers for categorical correction variables.
- encode(log_lib, K, k, correction)
Perform the encoding operation for log library sizes and correction variables
- Parameters:
log_lib (torch.Tensor) -- Tensor of log library sizes.
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
correction (bool) -- Flag indicating whether to apply correction using the correction variables.
- Returns:
A tuple containing:
mu: Tensor representing the mean of the latent space representation.
logvar: Tensor representing the log variance of the latent space representation.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- init_weights(m)
Initialize weights of Linear layers using Xavier initialization
- Parameters:
m (nn.Module) -- A PyTorch module instance.
- Returns:
None
- prepare_latent_space_with_korrection_vars(K, k, lat_space)
- Parameters:
K (List[torch.Tensor]) -- List of correction variable tensors.
k (List[int]) -- List of integers representing the number of categories for each correction variable. None values indicate continuous variables.
lat_space (torch.Tensor) -- Latent space representation tensor.
- Returns:
A tuple containing:
h: The modified latent space tensor with correction variables incorporated.
total_cat: The sum of the embedded categorical variables, or None if no categorical variables are present.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- training: bool
- class _model.VAE(N, M, ks, log, configs)
Bases:
Module
Variational Autoencoder (VAE) to capture the latent structure. Prepares the encoders, decoders, and linear layers of the VAE.
- Parameters:
N (int) -- Number of genes.
M (int) -- Number of samples.
ks (List[int]) -- List of integers representing the number of categories for each correction variable.
log (logging.Logger) -- Logger for outputting information during model operations.
configs (dict) -- Configuration dictionary containing parameters for the VAE.
- batch_centroid_loss(counts, Ks)
Computes a loss based on the Euclidean distance between centroids of each experimental batch in the latent space.
- Parameters:
counts (List[torch.Tensor]) -- List of tensors containing the latent representations per batch. Each element corresponds to a different tissue or condition and has shape [num_samples, latent_dim].
Ks (List[torch.Tensor]) -- List of tensors where the first column indicates batch membership for each sample in counts. Each tensor corresponds to a different tissue or condition and has shape [num_samples, num_batches].
- Returns:
A tensor containing the mean of the upper triangular non-zero Euclidean distances between batch centroids for each tissue or condition. Each element in the tensor corresponds to the computed distance for one of the tissues or conditions.
- Return type:
torch.Tensor
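The centroid-distance computation described above can be sketched as follows. This is an illustrative reading of the description, not the module's code: it assumes the first column of each `Ks` tensor holds integer batch ids, and it returns zero for a tissue with a single batch (an assumption, since the source does not say what happens in that case).

```python
from typing import List

import torch


def batch_centroid_loss(counts: List[torch.Tensor], Ks: List[torch.Tensor]) -> torch.Tensor:
    losses = []
    for Z, K in zip(counts, Ks):
        # Assumption: column 0 of K encodes batch membership as integer ids.
        labels = K[:, 0].long()
        # One centroid (mean latent vector) per experimental batch.
        centroids = torch.stack(
            [Z[labels == b].mean(dim=0) for b in torch.unique(labels)]
        )
        # Pairwise Euclidean distances between centroids.
        dists = torch.cdist(centroids, centroids)
        # Keep each pair once (strict upper triangle) and drop zero entries.
        upper = torch.triu(dists, diagonal=1)
        nonzero = upper[upper > 0]
        # Assumption: a tissue with no non-zero distances contributes 0.
        losses.append(nonzero.mean() if nonzero.numel() > 0 else Z.new_zeros(()))
    return torch.stack(losses)


# Toy data: two batches whose centroids are (0, 0) and (3, 4).
Z = torch.tensor([[0.0, 0.0], [0.0, 0.0], [3.0, 4.0], [3.0, 4.0]])
K = torch.tensor([[0.0], [0.0], [1.0], [1.0]])
loss = batch_centroid_loss([Z], [K])
```

Minimizing this quantity pulls the batch centroids together in latent space, which is the usual motivation for a batch-mixing penalty of this kind.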
- decode(latent_vars, Ks, ks, edges)
Decodes the latent variables and combines them with correction variables to calculate ZINB parameters for each type of decoder configuration.
- Parameters:
latent_vars (Dict[str, List[torch.Tensor]]) -- Dictionary containing reparameterized latent spaces computed by the encode_reparameterization() method. It includes keys like 'Z_Ls' for library sizes and 'Z_As' for sample-sample interactions.
Ks (List[torch.Tensor]) -- List of one-hot encoded matrices corresponding to batch or other categorical variables, one for each mini-batch.
ks (List[int]) -- List of integers representing the number of categories for each correction variable.
edges (List[torch.Tensor]) -- List of adjacency matrices or edge lists representing sample-sample interactions, one for each mini-batch.
- Returns:
A dictionary containing:
DLs: List of decoded library sizes.
DAs: List of decoded sample-sample interactions.
distributional_parameters: Dictionary containing the ZINB distributional parameters (pi, omega, and theta).
- Return type:
Dict[str, Union[List[torch.Tensor], Dict[str, torch.Tensor]]]
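To make the role of the three ZINB parameters concrete, here is a minimal sketch of a zero-inflated negative binomial log-probability. It assumes the common parameterization in which `pi` is the zero-inflation probability, `omega` is the mean, and `theta` is the inverse dispersion; the source names the parameters but does not define them, and the function `zinb_log_prob` is hypothetical.

```python
import torch


def zinb_log_prob(
    x: torch.Tensor,
    pi: torch.Tensor,
    omega: torch.Tensor,
    theta: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    log_theta_omega = torch.log(theta + omega + eps)
    # log NB(0 | omega, theta) = theta * log(theta / (theta + omega))
    log_nb_zero = theta * (torch.log(theta + eps) - log_theta_omega)
    # log NB(x | omega, theta) for general x
    log_nb = (
        torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
        + log_nb_zero
        + x * (torch.log(omega + eps) - log_theta_omega)
    )
    # Zeros come from either the inflation component or the NB itself.
    zero_case = torch.log(pi + (1 - pi) * torch.exp(log_nb_zero) + eps)
    # Positive counts can only come from the NB component.
    nonzero_case = torch.log(1 - pi + eps) + log_nb
    return torch.where(x == 0, zero_case, nonzero_case)


# Toy check: probabilities over a long range of counts should sum to about 1.
x = torch.arange(0, 200, dtype=torch.float32)
pi = torch.tensor(0.3)
omega = torch.tensor(2.0)
theta = torch.tensor(1.5)
log_p = zinb_log_prob(x, pi, omega, theta)
total = log_p.exp().sum()
```

The negative of this log-probability, summed over genes and samples, is the kind of reconstruction term a ZINB-based VAE would typically minimize.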
- encode_reparameterization(Xs, Ks, ks, edges)
Encodes and reparameterizes the input data to produce latent structures for the library size (Z_L) and network structure of sample-sample interactions (Z_A).
- Parameters:
Xs (List[torch.Tensor]) -- List of zeroed expression matrices, one for each mini-batch.
Ks (List[torch.Tensor]) -- List of one-hot encoded matrices corresponding to batch or other categorical variables, one for each mini-batch.
ks (List[int]) -- List of integers representing the number of categories for each correction variable.
edges (List[torch.Tensor]) -- List of adjacency matrices or edge lists representing sample-sample interactions, one for each mini-batch.
- Returns:
A dictionary containing two sub-dictionaries:
latent_spaces: Contains the mean (mu) and log variance (logvar) for the library size, network interactions, expression matrix, and simple models for each mini-batch.
latent_variables: Contains the reparameterized latent variables Z_L, Z_A, Z_X, and Z_simple for each mini-batch.
- Return type:
Dict[str, Dict[str, List[torch.Tensor]]]
- forward(batch)
Processes a batch of data through the VAE, performing encoding, reparameterization, and decoding steps to generate the outputs used for model training or inference.
- Parameters:
batch (Dict[str, List[torch.Tensor]]) --
A dictionary containing tensors that represent different parts of the data batch. Expected keys are:
'X_batches': Zeroed expression matrices of the minibatch.
'R_batches': Raw expression matrices of the minibatch.
'K_batches': Correction variables.
'k_batches': Levels of correction variables.
'idx_batches': Indices of samples in the minibatch.
'ej_batches': Graph edges in each minibatch (used if the model includes graph data).
- Returns:
A dictionary containing various outputs from the forward pass of the model, including:
'latent_spaces': The latent spaces derived from the encoder.
'latent_variables': Reparameterized latent variables.
'X_hat': Predicted data samples (e.g., reconstructed expression levels).
'DAs': Decoded sample-sample interactions.
'DLs': Decoded library size factors.
'lib_size_factors': Library size factors computed post-decoding.
'px_dispersion': Dispersion parameters of the distribution.
'px_omega': Mean (omega) parameters of the distribution.
'distributional_parameters': Parameters such as pi, omega, theta used in the distribution.
- Return type:
Dict[str, Union[torch.distributions.Distribution, List[torch.Tensor], Dict[str, torch.Tensor]]]
- init_weights_vae(m)
Initialize weights of Linear layers using Xavier initialization
- Parameters:
m (nn.Module) -- A PyTorch module instance.
- Returns:
None
- load_pretrained_model()
Loads a pre-trained model state into this model instance. It loads the model's state dictionary, updates the current model instance's parameters, and sets the model to evaluation mode.
- Raises:
Exception -- If there are any issues accessing the folder or loading the model file.
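The load, update, and eval sequence described above might look roughly like the following. The documented method takes no arguments, so the checkpoint location presumably comes from the model's configuration; the standalone `load_pretrained` function, its `path` argument, and the file name are hypothetical stand-ins.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_pretrained(model: nn.Module, path: str) -> nn.Module:
    try:
        # Load the saved state dictionary onto the CPU.
        state = torch.load(path, map_location="cpu")
    except Exception as exc:
        # Mirror the documented behavior of raising on folder or file problems.
        raise Exception(f"Could not load model file {path!r}") from exc
    model.load_state_dict(state)  # overwrite the current parameters
    model.eval()                  # disable dropout and similar training behavior
    return model


# Round trip with a toy module standing in for the VAE.
source = nn.Linear(4, 2)
target = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model.pt")  # hypothetical file name
    torch.save(source.state_dict(), ckpt)
    load_pretrained(target, ckpt)
```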
- loss(Rs_batch, idx_batch, adj_batch, dataset, prefix, epoch, generative_outputs, losses, log)
Calculates and records various losses during training or validation.
- Parameters:
Rs_batch (List[torch.Tensor]) -- Raw expression matrices for the current minibatch, where each tensor corresponds to a batch from a specific condition or tissue.
idx_batch (List[int]) -- List of indices corresponding to samples in the current minibatch.
adj_batch (List[torch.Tensor]) -- Adjacency matrices for samples in the minibatch, applicable for models considering sample-sample interactions.
dataset (FFPE_Dataset) -- Dataset object providing access to dataset properties and helper methods.
prefix (str) -- Indicates the phase of the model ('train' or 'val') during which the loss is being computed.
epoch (int) -- The current epoch number in the training/validation process.
generative_outputs (Dict[str, Any]) -- Outputs from the forward pass of the VAE model including latent variables and other intermediate data.
losses (Dict[str, float]) -- Dictionary to record and update the computed losses over training epochs.
log (Logger) -- Logger object for logging the computed losses.
- Returns:
The average loss computed across different metrics for the current minibatch.
- Return type:
torch.Tensor
- remove_ghost_samples(adj_Rs_batch, idx_batch, dataset, adjusted_generative_outputs)
Removes ghost samples from tensors, distributions and lists of vectors and matrices
- Parameters:
adj_Rs_batch (torch.Tensor) -- Expression tensor of the current minibatch of samples.
idx_batch (List[int]) -- List of indices of samples in the minibatch.
dataset (Dataset) -- Dataset object created in _data_loader.py.
adjusted_generative_outputs (Dict[str, Union[torch.distributions.Distribution, List[torch.Tensor], Dict[str, torch.Tensor]]]) -- Output dictionary from the VAE containing tensors, distributions, and lists of vectors and matrices.
- Returns:
None
- reparameterize(mu, logvar)
Reparameterization method to sample from a Gaussian distribution
- Parameters:
mu (torch.Tensor) -- Mean of the Gaussian distribution.
logvar (torch.Tensor) -- Natural log of the variance of the Gaussian distribution.
- Returns:
Sampled latent variable z.
- Return type:
torch.Tensor
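The standard reparameterization trick draws z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * logvar), which keeps the sample differentiable with respect to `mu` and `logvar`. A minimal sketch, assuming the method follows this standard form:

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # noise drawn from a standard normal
    return mu + eps * std


mu = torch.zeros(4, 3)
logvar = torch.zeros(4, 3)  # unit variance
z = reparameterize(mu, logvar)
```

Because the randomness is isolated in `eps`, gradients flow through `mu` and `std` during backpropagation.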
- set_parameter_requires_grad()
Sets the requires_grad attribute of model parameters to enable or disable training of specific layers.
- Usage:
This method is typically called after model initialization or loading a pre-trained model to prepare the model for fine-tuning or full training, depending on the experiment's requirements.
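Freezing layers for fine-tuning usually comes down to toggling `requires_grad` per parameter. The sketch below selects parameters by name prefix; the helper `set_requires_grad`, the prefix-based selection, and the toy module are all hypothetical, since the source does not say which layers the method targets or how they are chosen.

```python
from typing import Iterable

import torch
import torch.nn as nn


def set_requires_grad(model: nn.Module, trainable_prefixes: Iterable[str]) -> None:
    prefixes = tuple(trainable_prefixes)
    for name, param in model.named_parameters():
        # Parameters whose name matches a prefix stay trainable; the rest are frozen.
        param.requires_grad = name.startswith(prefixes)


class Tiny(nn.Module):
    """Hypothetical two-layer module standing in for the VAE."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 3)
        self.decoder = nn.Linear(3, 4)


model = Tiny()
set_requires_grad(model, ["decoder"])  # fine-tune only the decoder
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
```

Frozen parameters receive no gradients, so an optimizer built afterwards can be given only the parameters with `requires_grad=True`.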
- training: bool