-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathscript.py
More file actions
248 lines (206 loc) · 12 KB
/
script.py
File metadata and controls
248 lines (206 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""BBKNN (TS) code for removing batch effects.
This is the top-performing batch correction implementation discovered by the AI
system described in https://arxiv.org/abs/2509.06503.
"""
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
  'input': 'resources_test/.../input.h5ad',
  'output': 'output.h5ad'
}
meta = {
  'name': 'bbknn_ts'
}
## VIASH END
# NOTE(review): the placeholder `meta` above has no 'resources_dir' key, yet the
# boilerplate at the end of this file reads meta["resources_dir"]. Running the
# script outside viash presumably requires supplying that key — confirm.
################################################################################
####### LLM-written implementation of batch correction code starts here. #######
################################################################################
from typing import Any
from sklearn.decomposition import TruncatedSVD
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import lil_matrix, csr_matrix
import numpy as np
import scanpy as sc
import anndata as ad
import heapq # For efficiently getting top K elements from merged lists
# Hyper-parameters handed to eliminate_batch_effect_fn (which reads them with
# config.get). The values trade runtime against integration quality for inputs
# of up to roughly 300k cells x 2k genes.
config = dict(
    # PCA dimensionality (suggested range 50-200): keeps most of the variance
    # while shrinking the feature space used for neighbour search.
    n_pca_components=100,
    # Neighbours searched inside each individual batch (suggested 5-15); this
    # is each cell's local, per-batch context.
    n_neighbors_per_batch=10,
    # Neighbours retained per cell in the merged, batch-integrated graph
    # (suggested 15-100).
    total_k_neighbors=50,
)
def _per_batch_knn_edges(
    embedding: np.ndarray, batches: np.ndarray, k_batch_neighbors: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find every cell's nearest neighbours separately inside each batch.

    For each (query batch, target batch) pair, every query cell receives up to
    ``k_batch_neighbors`` Euclidean nearest cells of the target batch. Inside
    its own batch a cell is never its own neighbour but still receives up to
    ``k_batch_neighbors`` other cells.

    Args:
        embedding: Per-cell coordinates; row ``i`` belongs to cell ``i``.
        batches: Batch label of every cell, aligned with ``embedding`` rows.
        k_batch_neighbors: Neighbours requested from each batch.

    Returns:
        Flat arrays ``(rows, cols, dists)`` describing directed edges
        ``rows[i] -> cols[i]`` of length ``dists[i]``. Batches are disjoint,
        so every (row, col) pair appears at most once.
    """
    empty_idx = np.empty(0, dtype=np.intp)
    row_chunks, col_chunks, dist_chunks = [empty_idx], [empty_idx], [np.empty(0)]
    if k_batch_neighbors <= 0:
        return empty_idx, empty_idx, np.empty(0)

    groups = [np.flatnonzero(batches == label) for label in np.unique(batches)]
    for target_pos, target_idx in enumerate(groups):
        # One index per target batch, reused by every query batch and released
        # before the next one is built (keeps peak memory low).
        nn_model = NearestNeighbors(metric='euclidean', algorithm='auto')
        nn_model.fit(embedding[target_idx])
        for query_pos, query_idx in enumerate(groups):
            if query_pos == target_pos:
                # With X=None scikit-learn removes each point itself from its
                # own result (correct even for duplicate points); it requires
                # n_neighbors < number of fitted points.
                k = min(k_batch_neighbors, target_idx.size - 1)
                if k <= 0:
                    continue
                dists, local = nn_model.kneighbors(n_neighbors=k)
            else:
                k = min(k_batch_neighbors, target_idx.size)
                dists, local = nn_model.kneighbors(
                    embedding[query_idx], n_neighbors=k
                )
            row_chunks.append(np.repeat(query_idx, k))
            col_chunks.append(target_idx[local].ravel())
            dist_chunks.append(dists.ravel())

    return (
        np.concatenate(row_chunks),
        np.concatenate(col_chunks),
        np.concatenate(dist_chunks),
    )


def _closest_edges_per_cell(
    rows: np.ndarray, cols: np.ndarray, dists: np.ndarray, max_neighbors: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the ``max_neighbors`` shortest outgoing edges of each cell.

    Ties in distance are broken by the edges' original order (stable sort).
    """
    if rows.size == 0 or max_neighbors <= 0:
        return rows[:0], cols[:0], dists[:0]
    # lexsort treats the LAST key as primary: sort by source cell, then by
    # distance; it is stable, so equal distances keep their input order.
    order = np.lexsort((dists, rows))
    rows, cols, dists = rows[order], cols[order], dists[order]
    # Rank of every edge within the contiguous run of its source cell.
    run_starts = np.flatnonzero(np.r_[True, rows[1:] != rows[:-1]])
    run_lengths = np.diff(np.r_[run_starts, rows.size])
    rank_in_run = np.arange(rows.size) - np.repeat(run_starts, run_lengths)
    keep = rank_in_run < max_neighbors
    return rows[keep], cols[keep], dists[keep]


def _symmetric_knn_graph(
    rows: np.ndarray, cols: np.ndarray, dists: np.ndarray, n_obs: int
) -> tuple[csr_matrix, csr_matrix]:
    """Build undirected (distances, connectivities) matrices from directed edges.

    An edge is kept if either endpoint selected the other. Both directions
    carry the same Euclidean length (up to rounding), so ``maximum`` simply
    fills in the missing direction.
    """
    shape = (n_obs, n_obs)
    distances_matrix = csr_matrix((dists, (rows, cols)), shape=shape)
    distances_matrix = distances_matrix.maximum(distances_matrix.T)
    # A sparse matrix cannot distinguish a zero-length edge from "no edge",
    # so zero distances are not stored here ...
    distances_matrix.eliminate_zeros()
    # ... which is why connectivities are built from the edge list itself:
    # cells with identical embeddings stay connected.
    adjacency = csr_matrix((np.ones(rows.size), (rows, cols)), shape=shape)
    connectivities_matrix = adjacency.maximum(adjacency.T)
    return distances_matrix, connectivities_matrix


def eliminate_batch_effect_fn(
    adata: ad.AnnData, config: dict[str, Any]
) -> ad.AnnData:
    """Remove batch effects and build a batch-aware kNN graph.

    Pipeline: ``normalize_total(target_sum=1e4)`` -> ``log1p`` ->
    ``scale(max_value=10)`` -> ComBat on ``obs['batch']`` -> PCA (the PCA
    coordinates become ``obsm['X_emb']``) -> for every cell, up to
    ``n_neighbors_per_batch`` nearest cells from each batch, of which the
    ``total_k_neighbors`` closest are kept and symmetrised into
    ``obsp['distances']`` and a binary ``obsp['connectivities']``.

    Args:
        adata: Input data with batch labels in ``obs['batch']``; ``X`` is
            assumed to hold raw counts (TODO confirm with caller). It is not
            modified; all work happens on a copy.
        config: Optional keys ``'n_pca_components'`` (default 100),
            ``'n_neighbors_per_batch'`` (default 10) and
            ``'total_k_neighbors'`` (default 50).

    Returns:
        The processed copy, with ``obsm['X_emb']``, ``obsp['connectivities']``,
        ``obsp['distances']`` and ``uns['neighbors']`` populated. If the data
        are too small for PCA, a zero embedding and empty graphs are stored.
    """
    adata_integrated = adata.copy()

    # --- Preprocessing: normalise, log-transform, scale (clip outliers) ---
    sc.pp.normalize_total(adata_integrated, target_sum=1e4)
    sc.pp.log1p(adata_integrated)
    sc.pp.scale(adata_integrated, max_value=10)

    # --- ComBat on the expression matrix (modifies X in place) ---
    sc.pp.combat(adata_integrated, key='batch')

    # --- PCA on the ComBat-corrected data ---
    n_obs = adata_integrated.n_obs
    n_pca_components = config.get('n_pca_components', 100)
    # ARPACK needs n_comps strictly below min(n_obs, n_vars); scanpy's own
    # default applies the same "min_dim - 1" cap.
    actual_n_pca_components = min(
        n_pca_components, min(n_obs, adata_integrated.n_vars) - 1
    )
    if actual_n_pca_components <= 0 or n_obs <= 1:
        print(f"Warning: Too few observations ({adata_integrated.n_obs}) or dimensions ({adata_integrated.n_vars}) for PCA/graph construction. Returning trivial embedding.")
        # Placeholder embedding and empty graph so downstream code still runs.
        adata_integrated.obsm['X_emb'] = np.zeros((n_obs, 1))
        adata_integrated.obsp['connectivities'] = csr_matrix((n_obs, n_obs))
        adata_integrated.obsp['distances'] = csr_matrix((n_obs, n_obs))
        adata_integrated.uns['neighbors'] = {
            'params': {
                'n_neighbors': 0,
                'method': 'degenerate',
                'n_pcs': 0,
                'n_neighbors_per_batch': 0,
                'pca_batch_correction': 'none',
            },
            'connectivities_key': 'connectivities',
            'distances_key': 'distances',
        }
        return adata_integrated

    sc.tl.pca(adata_integrated, n_comps=actual_n_pca_components, svd_solver='arpack')
    embedding = adata_integrated.obsm['X_pca']
    # The ComBat-corrected PCA embedding is the method's integrated output.
    adata_integrated.obsm['X_emb'] = embedding

    # --- Batch-aware kNN graph: per-batch neighbours, merged per cell ---
    k_batch_neighbors = config.get('n_neighbors_per_batch', 10)
    total_k_neighbors = config.get('total_k_neighbors', 50)
    batches = np.asarray(adata_integrated.obs['batch'].values)

    rows, cols, dists = _per_batch_knn_edges(embedding, batches, k_batch_neighbors)
    rows, cols, dists = _closest_edges_per_cell(rows, cols, dists, total_k_neighbors)
    distances_matrix, connectivities_matrix = _symmetric_knn_graph(
        rows, cols, dists, n_obs
    )

    # Keys read by scanpy/scib-style graph consumers.
    adata_integrated.obsp['connectivities'] = connectivities_matrix
    adata_integrated.obsp['distances'] = distances_matrix
    adata_integrated.uns['neighbors'] = {
        'params': {
            'n_neighbors': total_k_neighbors,
            'method': 'custom_batch_aware_combat_pca',
            'metric': 'euclidean',
            'n_pcs': actual_n_pca_components,
            'n_neighbors_per_batch': k_batch_neighbors,
            'pca_batch_correction': 'combat',
        },
        'connectivities_key': 'connectivities',
        'distances_key': 'distances',
    }
    return adata_integrated
################################################################################
######## LLM-written implementation of batch correction code ends here. ########
################################################################################
################################################################################
# Start of human-written code. #################################################
################################################################################
# This is just boilerplate to satisfy the OpenProblems-codebase-specific setup
# for running evaluation.
import sys
# Make modules shipped in the component's resources directory importable.
# This must run before the `read_anndata_partial` import below.
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata
print('Read input', flush=True)
# read_anndata is a project helper. Judging by its keyword arguments it maps
# the file's 'layers/counts' entry onto X and also loads obs, var and uns —
# NOTE(review): confirm against read_anndata_partial.
input_adata = read_anndata(
    par['input'],
    X='layers/counts',
    obs='obs',
    var='var',
    uns='uns'
)
# Run the integration, tag the result with this method's id, and save it
# gzip-compressed to the requested output path.
output = eliminate_batch_effect_fn(input_adata, config=config)
output.uns['method_id'] = 'bbknn_ts'
output.write_h5ad(par['output'], compression='gzip')