_inference.py module
- class _inference.Inference(preffect_obj=None, task='inference', inference_key=None, inference_overwrite=True, configs=None)
Bases: object
- calculate_imputation_error(error_type='mse')
- concatenate_1d_minibatch(L, full_idx)
Concatenates sublists of 1D tensors or lists from multiple mini-batches into a single list.
- Parameters:
L (List[List[Union[torch.Tensor, List]]]) -- A nested list where each inner list contains tensors or lists from a specific mini-batch.
full_idx (List[int]) -- Indices representing the order or selection of elements to retain from the concatenated lists, adjusted for any padding that might have been applied during batching.
- Returns:
A list where each sublist corresponds to a concatenated and trimmed set of data across all mini-batches.
- Return type:
List[List]
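The sketch below illustrates the behaviour described above. It is hypothetical and is not the package's implementation. It assumes `L[b][k]` holds the k-th 1D output of mini-batch `b` as a plain list. It also assumes `full_idx` lists the positions to keep once the per-batch pieces are joined end to end, for example to drop padded entries. The function name is illustrative.

```python
from typing import List, Sequence


def concatenate_1d_sketch(L: List[List[Sequence]], full_idx: List[int]) -> List[List]:
    """Hypothetical illustration of concatenate_1d_minibatch (not the real code)."""
    n_outputs = len(L[0])  # every mini-batch is assumed to carry the same outputs
    result = []
    for k in range(n_outputs):
        joined = []
        for batch in L:
            # Append this mini-batch's k-th output after the previous batches' pieces.
            joined.extend(batch[k])
        # Keep only the requested positions, in the requested order.
        result.append([joined[i] for i in full_idx])
    return result


# Two mini-batches, each with two 1D outputs; the last slot of batch 2 is padding.
batches = [
    [[1, 2], ["a", "b"]],
    [[3, 0], ["c", "pad"]],
]
print(concatenate_1d_sketch(batches, full_idx=[0, 1, 2]))  # [[1, 2, 3], ['a', 'b', 'c']]
```

With torch tensors, `torch.cat` followed by indexing with `full_idx` would be the natural equivalent of the inner loop.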
- concatenate_2d_minibatch(L)
Concatenates sublists of tensors from multiple mini-batches into a single list of arrays.
- Parameters:
L (List[List[torch.Tensor]]) -- A list containing sublists, where each sublist consists of tensors from a particular mini-batch.
- Returns:
A list of concatenated arrays. Each array in the list corresponds to a type of data across all mini-batches (e.g., all latent variable tensors concatenated into a single array).
- Return type:
List[numpy.ndarray]
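The following is a dependency-free sketch of row-wise stacking across mini-batches. It is not the package's code. It assumes `L[b][k]` is the k-th 2D output of mini-batch `b`, stored as a list of rows with one row per sample. The documented function returns NumPy arrays, and for those `numpy.concatenate(pieces, axis=0)` would do the stacking. Plain lists are used here so that the example runs without extra packages.

```python
from typing import List

Matrix = List[List[float]]  # rows = samples, columns = features


def concatenate_2d_sketch(L: List[List[Matrix]]) -> List[Matrix]:
    """Hypothetical illustration of concatenate_2d_minibatch (not the real code)."""
    out = []
    for k in range(len(L[0])):
        pieces = [batch[k] for batch in L]
        widths = {len(row) for piece in pieces for row in piece}
        if len(widths) > 1:
            # np.concatenate(axis=0) likewise rejects mismatched column counts.
            raise ValueError(f"output {k}: inconsistent column counts {sorted(widths)}")
        # Stack rows batch after batch, preserving mini-batch order.
        out.append([row for piece in pieces for row in piece])
    return out


batches = [
    [[[1.0, 2.0]], [[0.1]]],                      # batch 0: one sample
    [[[3.0, 4.0], [5.0, 6.0]], [[0.2], [0.3]]],   # batch 1: two samples
]
result = concatenate_2d_sketch(batches)
print(result[0])  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(result[1])  # [[0.1], [0.2], [0.3]]
```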
- impute_values()
Impute missing values in the input data using the trained model.
This method uses the simplest form of imputation by returning the reconstructed counts as an AnnData object. The imputed values are obtained from the model's output after running inference on the input data.
- Returns:
An AnnData object containing the imputed counts.
- Return type:
AnnData
- Raises:
PreffectError -- If the inference has not been computed before calling this method.
- reconstruct_from_minibatches(data, idx_batches)
Reconstructs full dataset arrays from minibatches, combining the various outputs from the model's minibatch processing into unified structures.
- Parameters:
data (List[Dict]) -- List of dictionaries containing outputs from minibatches. Each dictionary should have keys corresponding to different aspects of the model's output, such as 'latent_variables', 'lib_size_factors', 'X_hat', etc.
idx_batches (List[List[torch.Tensor]]) -- List of tensors representing the indices of each minibatch within the overall dataset.
- Returns:
A dictionary containing the reconstructed full dataset. The keys in this dictionary correspond to the unified arrays of outputs such as 'Z_Ls', 'lib_size_factors', 'X_hat_mu', 'X_hat_theta', 'Z_As', and potentially others depending on the model configuration and type.
- Return type:
Dict
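Below is a minimal sketch of one part of this reconstruction: scattering per-sample minibatch outputs back into dataset order. It is an illustration under stated assumptions and not the actual method. It assumes `idx_batches` is a permutation of `0..n-1` split across (possibly shuffled) minibatches, with the indices given as plain ints. It also assumes each output is a per-sample list stored under a key such as `lib_size_factors`. The renaming of keys to names like 'Z_Ls' is not attempted here.

```python
from typing import Any, Dict, List


def reconstruct_sketch(data: List[Dict[str, List[Any]]],
                       idx_batches: List[List[int]]) -> Dict[str, List[Any]]:
    """Hypothetical illustration: put per-sample rows back in dataset order."""
    n_total = sum(len(idx) for idx in idx_batches)
    keys = list(data[0].keys())
    full: Dict[str, List[Any]] = {key: [None] * n_total for key in keys}
    for batch_out, idx in zip(data, idx_batches):
        for key in keys:
            # Row j of this minibatch belongs at dataset position idx[j].
            for row, sample_idx in zip(batch_out[key], idx):
                full[key][sample_idx] = row
    return full


# Shuffled minibatches: batch 0 held samples 2 and 0, batch 1 held sample 1.
data = [{"lib_size_factors": [1.5, 0.5]}, {"lib_size_factors": [2.0]}]
idx_batches = [[2, 0], [1]]
print(reconstruct_sketch(data, idx_batches))  # {'lib_size_factors': [0.5, 2.0, 1.5]}
```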
- register_inference_run()
Register the current inference instance within the parent object's inference dictionary.
This method checks if an inference instance with the same name already exists in the parent's inference dictionary. If it does not exist or if the overwrite permission is granted, the current inference instance is added to the dictionary using the specified inference key. If an instance with the same name already exists and overwrite permission is not granted, a PreffectError is raised.
- Raises:
PreffectError -- If an inference object with the same name already exists in the parent's dictionary and overwrite permission is set to False.
Note
- The parent object is assumed to be a Preffect object that has an inference_dict attribute.
- The configs_inf attribute of the current instance is used to determine the inference key and overwrite permission.
- The current inference instance is copied using copy.copy() before being added to the parent's dictionary to avoid unintended modifications.
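Here is a minimal sketch of the check-then-register logic described above. Several names are stand-ins chosen for illustration: `PreffectError`, `ParentStub`, the `parent` attribute, and the `configs_inf` keys `inference_key`/`inference_overwrite`. The key names mirror the constructor arguments, but the real class may store these values differently.

```python
import copy


class PreffectError(Exception):
    """Stand-in for the package's exception class (assumption)."""


class ParentStub:
    """Minimal stand-in for a Preffect object with an inference_dict."""

    def __init__(self):
        self.inference_dict = {}


class InferenceSketch:
    def __init__(self, parent, inference_key, inference_overwrite=True):
        self.parent = parent  # hypothetical attribute name
        # Hypothetical layout of configs_inf; only these two entries matter here.
        self.configs_inf = {
            "inference_key": inference_key,
            "inference_overwrite": inference_overwrite,
        }

    def register_inference_run(self):
        key = self.configs_inf["inference_key"]
        overwrite = self.configs_inf["inference_overwrite"]
        if key in self.parent.inference_dict and not overwrite:
            raise PreffectError(
                f"An inference named '{key}' already exists and overwrite is disabled."
            )
        # Shallow copy: rebinding attributes on self later does not change the
        # stored entry, although mutable attributes are still shared.
        self.parent.inference_dict[key] = copy.copy(self)


parent = ParentStub()
InferenceSketch(parent, "run1", inference_overwrite=False).register_inference_run()
try:
    InferenceSketch(parent, "run1", inference_overwrite=False).register_inference_run()
except PreffectError as err:
    print(err)
```

Because copy.copy() is shallow, shared mutable attributes such as configs_inf are not duplicated. A deep copy would be needed if full isolation were required.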
- return_counts_as_anndata()
Converts the raw and inferred gene expression data into an AnnData format for further analysis.
- Returns:
A list of AnnData objects, where each object represents a tissue or condition from the dataset.
- Return type:
List[anndata.AnnData]
- return_latent_space_as_anndata()
Creates an AnnData object from the latent variables (Z_L, Z_A, Z_Simple) depending on the model type.
- Returns:
A list of AnnData objects, where each object represents the latent space for a specific tissue or condition.
- Return type:
List[anndata.AnnData]
- Raises:
PreffectError -- If the model type is neither 'simple', 'single', nor 'full', an error is raised indicating that the method is not implemented for the given model type.
- run_inference()
Executes the inference process using the model configured in the parent instance. It handles different types of models including 'full', 'single', and 'simple'. It calculates and aggregates results from the model's output across all mini-batches.
- Parameters:
None
- save(results)
Saves the inference results to a file.
- Parameters:
results (Dict) -- The dictionary containing the inference results to be saved.
- save_visualization(vlib, filename=None)
Saves a matplotlib figure to a specified directory as a PDF file.
- Parameters:
vlib (matplotlib.figure.Figure) -- The matplotlib figure to be saved.
filename (str) -- The name of the file to save the figure as, including the file extension.
- Raises:
PreffectError -- If filename is not provided.
- visualize_batch_adjustment(infer_obj)
Generates side-by-side scatter plots comparing expected and fitted per-gene expression averages between sample batches, to assess whether batch correction was effective during training. If more than two batches are provided, every pair of batches is compared (e.g. batches 0+1, 0+2 and 1+2). Because the number of pairwise comparisons grows quadratically with the number of batches, only the first 5 batches are compared.
- Parameters:
infer_obj (Inference) -- An instance of the Inference class containing necessary data and configurations. This object should have access to training datasets for original expressions and inference results to fetch inferred library sizes.
- Returns:
A matplotlib figure containing the scatter plots, or None if there is only one batch or if batch correction is not being performed.
- Return type:
Optional[matplotlib.figure.Figure]
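The short sketch below shows how the capped set of pairwise comparisons could be enumerated. Treating "the first 5 batches" as the five smallest batch labels is an assumption, and `batch_pairs` is a hypothetical helper rather than part of the module.

```python
from itertools import combinations

MAX_BATCHES = 5  # cap stated in the docstring


def batch_pairs(batch_labels):
    """Hypothetical helper: which batch pairs would be compared."""
    # "First 5 batches" is taken here to mean the 5 smallest labels (assumption).
    batches = sorted(set(batch_labels))[:MAX_BATCHES]
    # A single batch yields no pairs, matching the documented None case.
    return list(combinations(batches, 2))


print(batch_pairs([0, 1, 2, 1, 0]))  # [(0, 1), (0, 2), (1, 2)]
print(len(batch_pairs(range(8))))    # 10
```

With n batches there are n(n-1)/2 pairs. The cap therefore bounds the figure at 10 panels however many batches are present.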
- visualize_fraction_pi_per_gene(infer_obj)
Generates a scatter plot comparing, per gene, the expected fraction of zeros and the mean Pi (zero-inflation) parameter. This compares the number of zeros present before and after passing the data through a ZINB model.
- Parameters:
infer_obj (Inference) -- An instance of the Inference class containing necessary data and configurations. This object should have access to training datasets for original expressions and inference results to fetch inferred library sizes.
- Returns:
A matplotlib figure containing a scatter plot.
- Return type:
matplotlib.figure.Figure
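The sketch below computes the two per-gene quantities this plot compares. It assumes that both the input counts and the model's Pi estimates are samples-by-genes matrices held as nested lists. The layout, the function names, and the toy values are hypothetical.

```python
from typing import List

Matrix = List[List[float]]  # rows = samples, columns = genes (assumed layout)


def zero_fraction_per_gene(counts: Matrix) -> List[float]:
    """Fraction of samples with a zero count, for each gene."""
    n_samples = len(counts)
    return [sum(row[g] == 0 for row in counts) / n_samples for g in range(len(counts[0]))]


def mean_pi_per_gene(pi: Matrix) -> List[float]:
    """Average zero-inflation probability across samples, for each gene."""
    n_samples = len(pi)
    return [sum(row[g] for row in pi) / n_samples for g in range(len(pi[0]))]


counts = [[0, 3], [0, 0], [5, 1], [0, 2]]
pi = [[0.7, 0.1], [0.8, 0.2], [0.6, 0.3], [0.7, 0.2]]
print(zero_fraction_per_gene(counts))  # [0.75, 0.25]
print(mean_pi_per_gene(pi))
```

Each gene contributes one point (zero fraction, mean Pi) to the scatter plot. Under a ZINB model the probability of a zero is pi + (1 - pi) * NB(0). Mean Pi alone therefore need not match the observed zero fraction, even when the model fits well.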
- visualize_gene_scatterplot(infer_obj)
Generates scatter plots comparing expected and fitted expression of the first 50 genes of the endogenous set.
- Parameters:
infer_obj (Inference) -- An instance of the Inference class containing necessary data and configurations. This object should have access to training datasets for original expressions and inference results to fetch inferred library sizes.
- Returns:
A matplotlib figure containing the scatter plots for both the expected and inferred read counts of the first 50 genes in the dataset.
- Return type:
matplotlib.figure.Figure
- visualize_latent_recons_umap(infer_obj, my_cmap=None)
Generates a UMAP of the latent space. Assumes that a 'batch' column is present in the obs table of the AnnData object.
- visualize_lib_size(infer_obj)
Generates histograms comparing expected and inferred library sizes from the given Inference object.
- Parameters:
infer_obj (Inference) -- An instance of the Inference class containing necessary data and configurations. This object should have access to training datasets for original expressions and inference results to fetch inferred library sizes.
- Returns:
A matplotlib figure containing the histogram plots for both the observed and inferred library sizes, and optionally a scatter plot comparing them.
- Return type:
matplotlib.figure.Figure
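The sketch below computes the quantities behind these plots. A sample's observed library size is its total count across genes. The optional scatter comparison can be summarised with a Pearson correlation. The samples-by-genes layout and the `inferred` values are made up for illustration. The sketch does not show how the module actually derives inferred library sizes.

```python
from typing import List, Sequence


def library_sizes(counts: List[List[float]]) -> List[float]:
    """Library size of each sample = total counts across genes (rows = samples)."""
    return [sum(row) for row in counts]


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation, e.g. for an observed-vs-inferred library size scatter."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = sum((a - mx) ** 2 for a in x) ** 0.5
    sy = sum((b - my) ** 2 for b in y) ** 0.5
    return cov / (sx * sy)


observed = library_sizes([[1, 2, 3], [4, 5, 6], [10, 0, 2]])
print(observed)  # [6, 15, 12]
inferred = [6.5, 14.0, 12.5]  # made-up model estimates, for illustration only
print(round(pearson(observed, inferred), 3))
```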
- visualize_libsize_and_dispersion(infer_obj)
Generates a series of plots to compare gene means, dispersion and library size.
- Parameters:
infer_obj (Inference) -- An instance of the Inference class containing necessary data and configurations. This object should have access to training datasets for original expressions and inference results to fetch inferred library sizes.
- Returns:
A matplotlib figure containing the 4-plot visualization.
- Return type:
matplotlib.figure.Figure