feat(downsample): AnnData input and output for downsample_cells #349
Chloe-Thangavelu wants to merge 5 commits into FNLCR-DMAP:dev from
Conversation
Modified the downsample_cells function to accept anndata.AnnData objects as input. When an AnnData object is provided, its .X and .obs data are combined into a pandas DataFrame before the rest of the downsampling logic is applied.
This commit modifies the 'downsample_cells' function and adds a helper function '_get_downsampled_indices' to provide cell downsampling capabilities for both AnnData objects and Pandas DataFrames.
This test now checks that AnnData objects are accepted, downsampled correctly, and returned as AnnData objects.
Reordering the code to convert annotations to a list before extracting annotation information, ensuring it is in DataFrame format as required by subsequent downsample_cells code.
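As a rough illustration of why an index-returning helper lets one code path serve both input types, here is a minimal sketch. The function name, signature, and sampling rule (at most `n_samples` rows per group) are assumptions made for illustration; they are not taken from the PR's `_get_downsampled_indices` implementation.

```python
import pandas as pd


def get_downsampled_indices(cell_data, grouping_col, n_samples, seed=0):
    """Hypothetical helper: index labels of at most n_samples rows per group.

    Assumes cell_data has a unique index (as AnnData obs_names usually are).
    """
    # Shuffle once so that "the first n rows of each group" is a random draw.
    shuffled = cell_data.sample(frac=1, random_state=seed)
    # cumcount() numbers the rows inside each group 0, 1, 2, ...
    keep = shuffled.groupby(grouping_col).cumcount() < n_samples
    kept_labels = shuffled.index[keep.to_numpy()]
    # Filter the original index so the surviving rows keep their input order.
    return cell_data.index[cell_data.index.isin(kept_labels)]


cells = pd.DataFrame(
    {
        "phenotype": ["A", "A", "A", "B", "B", "C"],
        "marker": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    },
    index=[f"cell_{i}" for i in range(6)],
)
kept = get_downsampled_indices(cells, "phenotype", n_samples=2)
downsampled = cells.loc[kept]
```

Because the helper returns labels rather than a new table, the same labels should also be usable to subset an AnnData object by observation name (for example `adata[kept]`), which is what makes an AnnData-in, AnnData-out path possible.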
Pull Request Overview
This PR updates the downsample_cells function to support input as an AnnData object, converting its .X and .obs into a DataFrame for downsampling, while maintaining compatibility with pandas DataFrames.
- Added conversion logic for AnnData input
- Updated error handling and documentation for input types
- Introduced a new unit test to validate AnnData processing
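The conversion described in the overview can be sketched roughly as follows. This is an assumption-laden illustration, not the PR's code: `combine_x_and_obs` is a hypothetical name, and the inputs are plain NumPy/pandas objects standing in for an AnnData's `.X`, `.obs`, and `.var_names`.

```python
import numpy as np
import pandas as pd


def combine_x_and_obs(X, obs, var_names):
    """Hypothetical helper: join an expression matrix with per-cell annotations."""
    # Assumes X is a dense 2-D array; a sparse matrix would need X.toarray() first.
    features = pd.DataFrame(np.asarray(X), index=obs.index, columns=list(var_names))
    # Column-wise concat aligns both frames on the shared cell index.
    return pd.concat([features, obs], axis=1)


obs = pd.DataFrame({"phenotype": ["A", "B", "A"]}, index=["c0", "c1", "c2"])
X = np.array([[1.0, 0.5], [0.2, 0.9], [0.7, 0.1]])
cell_data = combine_x_and_obs(X, obs, var_names=["CD3", "CD8"])
```

For a real AnnData object `adata`, the equivalent call would presumably be `combine_x_and_obs(adata.X, adata.obs, adata.var_names)`. Note that a feature name that collides with an annotation column would produce duplicate column labels, so a production version may want to check for that.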
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_data_utils/test_downsample_cells.py | Added a new unit test to ensure downsample_cells correctly processes AnnData objects |
| src/spac/data_utils.py | Refactored downsample_cells logic to handle both pandas DataFrame and AnnData inputs, updating docstrings and internal variable usage |
    logging.basicConfig(level=logging.WARNING)
    # Convert annotations to list if it's a string
    if isinstance(annotations, str):
        annotations = [annotations]

    # Check if the columns to downsample on exist
    missing_columns = [
        col for col in annotations if col not in input_data.columns
    ]
    if missing_columns:
        raise ValueError(
            f"Columns {missing_columns} do not exist in the dataframe"
        )

    # If n_samples is None, return the input data without processing
    if n_samples is None:
        return input_data.copy()

    # Combine annotations into a single column if multiple annotations
    if len(annotations) > 1:
The repeated call to logging.basicConfig in both downsample_cells and _get_downsampled_indexes may cause configuration conflicts; consider configuring logging once at application startup.
-    logging.basicConfig(level=logging.WARNING)
-    # Convert annotations to list if it's a string
-    if isinstance(annotations, str):
-        annotations = [annotations]
-    # Check if the columns to downsample on exist
-    missing_columns = [
-        col for col in annotations if col not in input_data.columns
-    ]
-    if missing_columns:
-        raise ValueError(
-            f"Columns {missing_columns} do not exist in the dataframe"
-        )
-    # If n_samples is None, return the input data without processing
-    if n_samples is None:
-        return input_data.copy()
-    # Combine annotations into a single column if multiple annotations
-    if len(annotations) > 1:
+    # Combine annotations into a single column if multiple annotations
+    if len(annotations) > 1:
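One common way to act on the logging comment above is to give the library module a named logger and leave configuration to the application entry point. `logging.basicConfig` does nothing when the root logger already has handlers (unless `force=True` is passed), so calling it inside library functions is at best redundant. The sketch below is illustrative only; the logger name and the helper are hypothetical and not part of the PR.

```python
import logging

# Library module: create a named logger, but do not configure handlers here.
logger = logging.getLogger("spac_sketch.data_utils")  # hypothetical logger name


def warn_small_group(group, size, min_threshold):
    """Hypothetical helper: emits a record but never calls basicConfig."""
    if size < min_threshold:
        logger.warning("Group %s has only %d cells", group, size)
        return True
    return False


def main():
    # Application entry point: the one place where logging is configured.
    logging.basicConfig(level=logging.WARNING)
    warn_small_group("A", 3, min_threshold=5)


if __name__ == "__main__":
    main()
```

With this layout, a script or notebook that imports the module decides the log level and format once, and the library functions simply emit records.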
    else:
        raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.")
For consistency and clarity, update the error message to refer to 'AnnData' (with proper casing) instead of 'Anndata'.
-    else:
-        raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.")
+    else:
+        raise TypeError("Input data must be a Pandas DataFrame or AnnData Object.")
@@ -586,62 +691,32 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
        annotation columns are provided.
The "combined_col_name" parameter is documented but never used in the code.
Since grouping_col is a pd.Series and not a new column in the cell_data DataFrame, the combined_col_name parameter isn't strictly necessary. Either remove the parameter or assign its value as the name of the Series.
src/spac/data_utils.py
Outdated
     if len(annotations) > 1:
-        input_data[combined_col_name] = input_data[annotations].apply(
+        grouping_col = cell_data[annotations].apply(
             lambda row: '_'.join(row.values.astype(str)), axis=1)
The apply method for combining annotations is readable but can be slow on very large datasets. A more performant, vectorized approach is to use str.cat or agg.
grouping_col = cell_data[annotations].astype(str).agg('_'.join, axis=1)
(Ensure all columns are string type first)
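To make the suggestion above concrete, the following sketch compares the original row-wise `apply` with the two alternatives the comment names, on a small made-up table. The column names and data are assumptions for illustration. As far as I understand, `agg(..., axis=1)` still works row by row internally, so `str.cat` is the option that operates column-wise; actual speedups would need to be measured on real data.

```python
import pandas as pd

cell_data = pd.DataFrame({"phenotype": ["A", "B", "A"], "region": [1, 2, 1]})
annotations = ["phenotype", "region"]

# Original approach: build each label with a Python-level function per row.
via_apply = cell_data[annotations].apply(
    lambda row: "_".join(row.values.astype(str)), axis=1
)

# Suggested approach: cast to str once, then join across columns.
via_agg = cell_data[annotations].astype(str).agg("_".join, axis=1)

# Column-wise alternative: concatenate string columns with Series.str.cat.
as_str = cell_data[annotations].astype(str)
via_cat = as_str[annotations[0]].str.cat(as_str[annotations[1:]], sep="_")
```

All three produce the same labels here; the `astype(str)` cast is what lets non-string annotation columns (such as the integer `region`) be joined.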
            lambda row: '_'.join(row.values.astype(str)), axis=1)
        grouping_col = combined_col_name
    else:
        grouping_col = annotations[0]
The variable grouping_col is used inconsistently: it is a Series when multiple annotations are combined, but a column name (string) otherwise.
Suggested:
    if len(annotations) > 1:
        cell_data[combined_col_name] = cell_data[annotations].apply(
            lambda row: '_'.join(row.values.astype(str)), axis=1)
        grouping_col = combined_col_name
    else:
        grouping_col = annotations[0]  # This is a string, not a Series
Thank you, I will go through & add these suggestions today
        combined_col_name='_combined_',
        min_threshold=5
    )
The test calls the function with stratify=False and min_threshold=5. In the downsample_cells function, the min_threshold parameter is only used when stratify=True.
The changes should be complete now. I have tested the code and it looks like it's working. Let me know if any other fixes need to be made.
When multiple annotations are provided, a new temporary column (named by `combined_col_name`) is now explicitly added to the `cell_data` DataFrame, making `grouping_col` consistently a column name (string).
Replaced slow `DataFrame.apply` with the vectorized `DataFrame.astype(str).agg('_'.join, axis=1)` for combining annotations
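The resolution described above, where `grouping_col` is always a column name, might look something like the sketch below. `resolve_grouping_column` is a hypothetical helper written for illustration; the PR performs this logic inline in `downsample_cells`, and whether it copies the frame is not stated in the conversation.

```python
import pandas as pd


def resolve_grouping_column(cell_data, annotations, combined_col_name="_combined_"):
    """Hypothetical helper: always return (DataFrame, column name as str)."""
    if isinstance(annotations, str):
        annotations = [annotations]
    if len(annotations) == 1:
        # Single annotation: the existing column is the grouping column.
        return cell_data, annotations[0]
    # Assumption: work on a copy so the caller's frame gains no temporary column.
    cell_data = cell_data.copy()
    cell_data[combined_col_name] = (
        cell_data[annotations].astype(str).agg("_".join, axis=1)
    )
    return cell_data, combined_col_name


cells = pd.DataFrame({"phenotype": ["A", "B"], "region": ["x", "y"]})
single_df, single_col = resolve_grouping_column(cells, "phenotype")
multi_df, multi_col = resolve_grouping_column(cells, ["phenotype", "region"])
```

Returning a string in both branches means downstream code can call `groupby(grouping_col)` without checking whether it received a Series or a name.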
bb4b7a3 to 1915dca
Summary:
Modified downsample_cells function to accept anndata.AnnData objects as input. When an AnnData object is provided, .X and .obs data are combined into a pandas DataFrame, before applying the rest of the downsampling function.
Changes: