_utils.py module
- _utils.To(dev, objects)
Transfers each object in the provided dictionary to the specified device using PyTorch's .to() method.
- Parameters:
dev (torch.device or str) -- The target device to which the tensors or models should be moved. It can be a torch.device object or a string specifying the device (e.g., 'cpu', 'cuda').
objects (dict) -- A dictionary where each key-value pair consists of a name (key) and a PyTorch tensor, model, or a list of tensors/models (value).
- Returns:
A dictionary with the same keys as the input, but with each tensor or model moved to the specified device.
- Return type:
dict
- _utils.ZINB_expected_value(mu, logits, distribution)
Calculate the expected value from a Zero-Inflated Negative Binomial (ZINB) or Negative Binomial (NB) distribution.
- Parameters:
mu (torch.Tensor) -- The mean of the negative binomial distribution.
logits (torch.Tensor) -- The logits for the zero-inflation component.
distribution (str) -- The type of distribution, either "ZINB" or "NB".
- Returns:
The expected values computed from the NB/ZINB distribution.
- Return type:
torch.Tensor
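For reference, the mean of a zero-inflated negative binomial is the NB mean scaled by the probability of not falling in the zero-inflation component. The sketch below works on plain floats rather than tensors, and it assumes the logit is the log-odds of the zero-inflation probability; the text above does not confirm that _utils.ZINB_expected_value uses this convention. The name zinb_expected_value_sketch is hypothetical.

```python
import math


def zinb_expected_value_sketch(mu, logit, distribution="ZINB"):
    """Expected value of an NB or ZINB variable for scalar inputs.

    Assumption: `logit` is the log-odds of the zero-inflation probability
    pi, so pi = sigmoid(logit) and E[X] = (1 - pi) * mu.
    """
    if distribution == "NB":
        # No zero inflation: the expectation is the NB mean itself.
        return mu
    if distribution == "ZINB":
        pi = 1.0 / (1.0 + math.exp(-logit))
        return (1.0 - pi) * mu
    raise ValueError(f"Unknown distribution: {distribution!r}")


# A logit of 0 means a 50% zero-inflation probability, halving the mean.
halved = zinb_expected_value_sketch(10.0, 0.0, "ZINB")
```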
- _utils.adjusted_categorical_correction_variable(df, cat_var, target_value)
Prepares a categorical variable in a DataFrame for one-hot encoding while assigning a fixed index to a target value. It returns the adjusted indices, a mapping of categories to indices, and the number of unique categories.
- Parameters:
df (pandas.DataFrame) -- The input DataFrame containing the categorical variable.
cat_var (str) -- The column name of the categorical variable to be encoded.
target_value (int or str) -- The specific category value to be assigned a fixed index.
- Returns:
A tuple containing:
indices (torch.Tensor): The adjusted indices for the categorical variable.
named_cat_map (dict): A dictionary mapping the categorical variable name to a dictionary of category-to-index mappings.
num_unique (int): The number of unique categories in the categorical variable.
- Return type:
tuple
- _utils.calculate_and_sort_by_iqr(expr, gene_names=None, force_pam50=False, force_bc_genes=False)
Calculates the interquartile range (IQR) of each column in a table and returns the column indices sorted by IQR.
- Parameters:
expr (numpy.ndarray) -- A table of expression data where columns represent a gene and rows are samples.
gene_names (list of str, optional) -- Names of the genes (only needed if force_pam50 is True).
force_pam50 (bool, optional) -- Flag indicating whether to force the use of PAM50 genes (not used in the function).
force_bc_genes (bool, optional) -- Flag indicating whether to force the use of BC genes (not used in the function).
- Returns:
Sorted indices for the columns of the input table based on IQR in descending order.
- Return type:
numpy.ndarray
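The ranking step can be illustrated without NumPy. The sketch below is a standard-library stand-in, not the function's actual code: it takes a list of rows instead of a numpy.ndarray, and it assumes linear-interpolation quartiles (the "inclusive" method of statistics.quantiles). The name sort_columns_by_iqr_sketch is hypothetical.

```python
import statistics


def sort_columns_by_iqr_sketch(rows):
    """Return column indices ordered by interquartile range, largest first."""
    columns = list(zip(*rows))  # transpose: one tuple per gene (column)
    iqrs = []
    for col in columns:
        q1, _, q3 = statistics.quantiles(col, n=4, method="inclusive")
        iqrs.append(q3 - q1)
    # Sorting on the negated IQR yields descending order.
    return sorted(range(len(columns)), key=lambda j: -iqrs[j])


# Rows are samples, columns are genes.
expr = [
    [1.0, 10.0, 5.0],
    [2.0, 20.0, 5.0],
    [3.0, 30.0, 5.0],
    [4.0, 40.0, 5.0],
    [5.0, 50.0, 5.0],
]
order = sort_columns_by_iqr_sketch(expr)
```

Here the second gene varies most and the third not at all, so the second column ranks first and the third last.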
- _utils.categorical_correction_variable(df, cat_var)
Prepare a categorical variable for one-hot encoding.
This function prepares a categorical variable in a DataFrame for one-hot encoding. It returns the indices, a mapping of categories to indices, and the number of unique categories.
- Parameters:
df (pandas.DataFrame) -- The input DataFrame containing the categorical variable.
cat_var (str) -- The column name of the categorical variable to be encoded.
- Returns:
A tuple containing:
indices (torch.Tensor): The indices for the categorical variable.
named_cat_map (dict): A dictionary mapping the categorical variable name to a dictionary of category-to-index mappings.
num_unique (int): The number of unique categories in the categorical variable.
- Return type:
tuple
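The three return values can be pictured with a small stand-in that uses a plain list of values instead of a DataFrame column and a list instead of a torch.Tensor. It assumes categories are indexed in sorted order, which the text above does not state; categorical_indices_sketch is a hypothetical name.

```python
def categorical_indices_sketch(values, cat_var):
    """Map each category to an integer index.

    Assumption: indices follow the sorted order of the unique categories.
    """
    categories = sorted(set(values))
    cat_map = {cat: idx for idx, cat in enumerate(categories)}
    indices = [cat_map[v] for v in values]
    # The mapping is keyed by the variable name, as described for named_cat_map.
    return indices, {cat_var: cat_map}, len(categories)


indices, named_cat_map, num_unique = categorical_indices_sketch(
    ["B", "A", "B", "C"], "batch"
)
```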
- _utils.check_folder_access(path)
Checks that the specified folder path exists, is readable, and ends with a trailing '/'.
- Parameters:
path (str) -- The filesystem path of the directory to check. This path should be absolute or relative to the current working directory.
- Raises:
FileNotFoundError -- If the directory does not exist at the specified path.
PermissionError -- If the directory exists but does not have read permissions enabled.
ValueError -- If the directory path does not end with a '/'.
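The three documented failure modes map naturally onto standard-library checks. This is a minimal sketch under assumptions: the order of the checks and the error messages are not given in the text above, and check_folder_access_sketch is a hypothetical name.

```python
import os


def check_folder_access_sketch(path):
    """Raise if `path` lacks a trailing '/', is missing, or is unreadable."""
    # Assumed order: format first, then existence, then permissions.
    if not path.endswith("/"):
        raise ValueError(f"Directory path must end with '/': {path}")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Directory does not exist: {path}")
    if not os.access(path, os.R_OK):
        raise PermissionError(f"Directory is not readable: {path}")


# The current working directory should pass once a trailing slash is added.
check_folder_access_sketch(os.getcwd() + "/")
```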
- _utils.check_for_nans(obj, msg)
Checks for NaN values in a PyTorch tensor and raises an error if found.
- Parameters:
obj (torch.Tensor) -- The tensor to check for NaN values.
msg (str) -- The error message to display if NaN values are found.
- Raises:
ValueError -- If NaN values are found in the obj tensor.
- _utils.check_int_and_uniq(file_nums)
Checks that a list contains only integers and that they are unique.
- Parameters:
file_nums (list of int) -- A list of values from filenames after the phrase "tau".
- Raises:
ValueError -- If any element in the list is not an integer or if integers are not unique.
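A check of this kind takes only a few lines. The sketch below is one possible version, not the module's code; rejecting booleans (which are a subclass of int in Python) is an added assumption, and check_int_and_uniq_sketch is a hypothetical name.

```python
def check_int_and_uniq_sketch(file_nums):
    """Raise ValueError unless every element is an int and none repeats."""
    for n in file_nums:
        # bool is a subclass of int; excluding it is an assumption.
        if not isinstance(n, int) or isinstance(n, bool):
            raise ValueError(f"Not an integer: {n!r}")
    if len(set(file_nums)) != len(file_nums):
        raise ValueError("Values must be unique.")


check_int_and_uniq_sketch([1, 2, 3])  # passes silently
```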
- _utils.copy_file(source_path, destination_path)
Copies a file from the source path to the destination path.
- Parameters:
source_path (str) -- The path to the source file that needs to be copied.
destination_path (str) -- The path to the destination where the file should be copied.
- Raises:
IOError -- If an error occurs during the file copying process, such as the source file not existing, no permission to read the file, or issues with the destination path.
Exception -- If any other unforeseen exception occurs during the copying process.
- _utils.ensure_directory(directory_path)
Ensures that a directory exists at the specified path. If the directory does not exist, it is created with specific permissions. If it already exists, the function does nothing.
- Parameters:
directory_path (str) -- The filesystem path where the directory should exist.
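The create-if-missing behaviour matches what os.makedirs offers with exist_ok=True. In the sketch below the mode value 0o755 is a placeholder, since the text above does not say which permissions are used, and ensure_directory_sketch is a hypothetical name. Note that the process umask can further restrict the mode passed to os.makedirs.

```python
import os
import tempfile


def ensure_directory_sketch(directory_path, mode=0o755):
    """Create the directory (and parents) if missing; otherwise do nothing."""
    # mode=0o755 is an assumed value, not taken from the documentation.
    os.makedirs(directory_path, mode=mode, exist_ok=True)


base = tempfile.mkdtemp()
target = os.path.join(base, "logs", "run_1")
ensure_directory_sketch(target)
ensure_directory_sketch(target)  # second call is a no-op
```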
- _utils.logging_helper(dic, log)
Logs tensor values along with their corresponding keys from a dictionary.
- Parameters:
dic (dict) -- The dictionary containing keys and tensor values to be logged.
log (logging.Logger) -- Logger used for outputting tensor information.
- _utils.logging_tensor(log, tensor, msg)
Logs the details of a tensor or a list of tensors, including size, dtype, minimum and maximum values, and whether NaNs or infinities are present.
- Parameters:
log (logging.Logger) -- The logging object used for logging the tensor details.
tensor (torch.Tensor or list[torch.Tensor]) -- The tensor or list of tensors to be logged.
msg (str) -- The message to be printed before logging the tensor details.
- _utils.multi_logging_tensor(things)
Logs details of multiple tensors from a list of tuples, where each tuple contains the parameters intended for the logging_tensor function. This allows batch logging of tensor details for monitoring tensor states across different stages or components of a model.
- Parameters:
things (list of tuple) --
A list of tuples, where each tuple contains three elements:
log (logging.Logger): The logger object used for logging.
tensor (torch.Tensor or list[torch.Tensor]): The tensor(s) to be logged.
msg (str): A descriptive message that precedes the tensor details in the log.
- _utils.one_hot_encode(column_tensor)
Perform one-hot encoding on a tensor.
This function takes a tensor representing a categorical variable and performs one-hot encoding on it. If the categorical variable is invariant (i.e., has only one unique value), it returns a tensor of ones with the same number of rows as the input tensor and one column. Otherwise, it applies one-hot encoding using PyTorch's F.one_hot function.
- Parameters:
column_tensor (torch.Tensor) -- The tensor representing the categorical variable to be one-hot encoded.
- Returns:
The one-hot encoded tensor.
- Return type:
torch.Tensor
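The two branches described above can be mimicked with nested lists. This stand-in does not use PyTorch; it assumes the input holds non-negative integer class indices and, like F.one_hot with an inferred class count, uses max index + 1 columns. The name one_hot_encode_sketch is hypothetical.

```python
def one_hot_encode_sketch(indices):
    """One-hot encode integer class indices as a list of rows."""
    if len(set(indices)) == 1:
        # Invariant variable: a single column of ones, one row per sample.
        return [[1] for _ in indices]
    width = max(indices) + 1
    return [[1 if j == i else 0 for j in range(width)] for i in indices]


encoded = one_hot_encode_sketch([0, 2, 1])
invariant = one_hot_encode_sketch([3, 3])
```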
- _utils.plot_progression(target, title)
Plots the progression of average loss across mini-batches and epochs.
- Parameters:
target (list or numpy.ndarray) -- A list or NumPy array containing average loss values at each mini-batch or epoch.
title (str) -- The title to display on the plot.
- _utils.plot_progression_all(losses, epoch, x_dim=2, y_dim=4, override=False, file_path=None)
Plots the progression of multiple types of loss metrics during training and validation across epochs.
- Parameters:
epoch (int) -- The current epoch number for which the losses are being plotted.
x_dim (int, optional) -- The number of rows in the subplot grid. Default is 2.
y_dim (int, optional) -- The number of columns in the subplot grid. Default is 4.
override (bool, optional) -- If True, overrides the default subplot grid dimensions and sets them based on the number of unique plots. Default is False.
file_path (str, optional) -- The file path where the plot should be saved. If not specified, the plot is saved as 'loss_progression_epoch_{epoch}.pdf'.
- _utils.print_loss_table(data, log)
Prints a well-formatted table for loss metrics or other numerical data.
- Parameters:
data (list of lists) -- The data to be printed. Each sublist is considered a row in the table. The first row is assumed to be the header. Cells can contain numerical values (int, float), NumPy arrays, or strings.
log (logging.Logger) -- The logger object used for printing the table.
- _utils.reparameterize_gaussian(mu, var)
Reparameterizes a Gaussian distribution and samples from it.
- Parameters:
mu (torch.Tensor) -- The mean of the Gaussian distribution.
var (torch.Tensor) -- The variance of the Gaussian distribution.
- Returns:
A sample from the Gaussian distribution parameterized by mu and var, with the same shape as mu and var.
- Return type:
torch.Tensor
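The reparameterization trick draws standard normal noise and then shifts and scales it, so the randomness stays separate from mu and var. The scalar sketch below illustrates the formula z = mu + sqrt(var) * eps with the standard library; the real function operates on tensors, and reparameterize_gaussian_sketch and its rng parameter are hypothetical.

```python
import math
import random


def reparameterize_gaussian_sketch(mu, var, rng=random):
    """Sample z = mu + sqrt(var) * eps with eps drawn from N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + math.sqrt(var) * eps


# A seeded generator makes the draw reproducible.
sample = reparameterize_gaussian_sketch(0.0, 1.0, rng=random.Random(0))
```

With var set to 0 the noise term vanishes and the function returns mu unchanged.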
- _utils.sanity_check_on_configs(preffect_con=None, train_ds_con=None, valid_ds_con=None)
Performs a series of assertions on configuration variables to ensure they meet predefined criteria necessary for PREFFECT to function.
- Parameters:
preffect_con (dict, optional) -- The configuration dictionary for the PREFFECT system. Defaults to None.
train_ds_con (dict, optional) -- The configuration dictionary for the training dataset. Defaults to None.
valid_ds_con (dict, optional) -- The configuration dictionary for the validation dataset. Defaults to None.
- Raises:
PreffectE -- If any of the configuration variables do not pass their respective checks, indicating a critical mismatch in the expected environment setup.
- _utils.selective_one_hot(df, cat_var)
One-hot encodes a specific column in a DataFrame and removes all other columns.
- Parameters:
df (pandas.DataFrame) -- The input DataFrame.
cat_var (str) -- The column name to be one-hot encoded.
- Returns:
A DataFrame containing only the one-hot encoded column of cat_var.
- Return type:
pandas.DataFrame
- _utils.set_seeds(config_seed)
Sets the seed for generating random numbers to ensure reproducibility across various random number generators.
- Parameters:
config_seed (int, optional) -- The seed value to use for all random number generators. If None, no seed is set.
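A seeding helper of this kind typically covers Python's random module plus NumPy and PyTorch. Which generators the real function seeds is not listed above, so the selection below is an assumption; the optional imports keep the sketch runnable when NumPy or PyTorch is absent. The name set_seeds_sketch is hypothetical.

```python
import random


def set_seeds_sketch(config_seed=None):
    """Seed the available random number generators; do nothing for None."""
    if config_seed is None:
        return
    random.seed(config_seed)
    try:
        import numpy as np
        np.random.seed(config_seed)
    except ImportError:
        pass  # NumPy not installed; skip it
    try:
        import torch
        torch.manual_seed(config_seed)
    except ImportError:
        pass  # PyTorch not installed; skip it


set_seeds_sketch(0)
first = random.random()
set_seeds_sketch(0)
second = random.random()  # equals `first` because the seed was reset
```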
- _utils.target_specific_one_hot(df, cat_var, target_value)
One-hot encodes a specific column in a DataFrame for a target value. It provides a one-hot encoded representation where only the target value is active.
- Parameters:
df (pandas.DataFrame) -- The input DataFrame containing the categorical variable.
cat_var (str) -- The column name to be one-hot encoded.
target_value (int or str) -- The specific category value in the column to encode as 1; all others are set to 0.
- Returns:
A tensor with the one-hot encoded column where only the target value is 1, and all other category values are 0.
- Return type:
torch.Tensor
- _utils.torch_mtx_unbatching(mtx, idx_batches, dataset, device)
Reassembles mini-batched matrices into a single matrix whose rows are ordered by their sample indices.
- Parameters:
mtx (list of torch.Tensor) -- A list of mini-batched torch matrices.
idx_batches (list of tuple of torch.Tensor) -- A list of mini-batch indices, where each element is a tuple containing a tensor of indices.
dataset (Dataset) -- The dataset object used to retrieve ghost indices.
device (torch.device) -- The device on which to perform the computations.
- Returns:
The rearranged matrix with rows sorted based on the indices.
- Return type:
torch.Tensor
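Based on the Returns entry above, the core operation is to pair every row with its original sample index and then sort on that index. The sketch below shows that idea with nested lists only; it ignores the ghost indices and device handling mentioned in the parameters, whose behaviour is not described here, and unbatch_rows_sketch is a hypothetical name.

```python
def unbatch_rows_sketch(batches, idx_batches):
    """Merge mini-batches and order the rows by their sample index."""
    pairs = []
    for rows, idxs in zip(batches, idx_batches):
        # Attach each row to the index it had in the full dataset.
        pairs.extend(zip(idxs, rows))
    pairs.sort(key=lambda pair: pair[0])
    return [row for _, row in pairs]


batches = [[[30.0], [10.0]], [[20.0], [0.0]]]
idx_batches = [[3, 1], [2, 0]]
restored = unbatch_rows_sketch(batches, idx_batches)
```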
- _utils.umap_draw_latent(results_from_forward, batch_info)
Visualize the latent space using UMAP dimensionality reduction.
This function takes the latent variables from the forward pass results and applies UMAP (Uniform Manifold Approximation and Projection) dimensionality reduction to visualize the latent space. It plots the reduced latent space in a 2D scatter plot, where each point represents a sample and is colored based on the corresponding batch information.
- Parameters:
results_from_forward (dict) -- A dictionary containing the results from the forward pass, including the latent variables.
batch_info (numpy.ndarray) -- An array containing the batch information for each sample.
- _utils.update_composite_configs(configs)
Update the composite configuration dictionary with derived paths.
This function takes a configuration dictionary configs and updates it with derived paths for logs, inference results, and output files. It ensures that the output_path ends with a forward slash ('/') and constructs the derived paths by joining the output_path with the respective subdirectories using os.path.join and os.sep.
- Parameters:
configs (dict, optional) -- The configuration dictionary to be updated. If None, the function returns None.
- Returns:
The updated configuration dictionary with derived paths, or None if the input configs is None.
- Return type:
dict, optional
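The path handling described above can be sketched as follows. Only output_path is named in the text; the derived keys log_path and inference_path and the subdirectory names are placeholders chosen for illustration, and update_composite_configs_sketch is a hypothetical name.

```python
import os


def update_composite_configs_sketch(configs=None):
    """Normalize output_path and add derived paths; pass None through."""
    if configs is None:
        return None
    output_path = configs["output_path"]
    if not output_path.endswith("/"):
        output_path += "/"
    configs["output_path"] = output_path
    # Hypothetical derived keys and subdirectory names.
    configs["log_path"] = os.path.join(output_path, "logs") + os.sep
    configs["inference_path"] = os.path.join(output_path, "inference") + os.sep
    return configs


cfg = update_composite_configs_sketch({"output_path": "out"})
```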