Skip to content

Commit 59ac0dd

Browse files
committed
z_score
1 parent e633e3a commit 59ac0dd

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

src/flowsom/models/consensus_cluster.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ class ConsensusCluster(BaseClusterEstimator):
3535
"""
3636

3737
def __init__(
38-
self, n_clusters, K=None, H=100, resample_proportion=0.9, linkage="average", cluster=AgglomerativeClustering
38+
self,
39+
n_clusters,
40+
K=None,
41+
H=100,
42+
resample_proportion=0.9,
43+
linkage="average",
44+
z_score=False,
45+
z_cap=3, # ignored if z_score is False
46+
cluster=AgglomerativeClustering,
3947
):
4048
super().__init__()
4149
assert 0 <= resample_proportion <= 1, "proportion has to be between 0 and 1"
@@ -45,6 +53,9 @@ def __init__(
4553
self.resample_proportion = resample_proportion
4654
self.cluster = cluster
4755
self.linkage = linkage
56+
self.z_score = z_score
57+
assert z_cap > 0, f"z_cap should be stricly positive, but got {z_cap}"
58+
self.z_cap = z_cap
4859

4960
def _internal_resample(self, data, proportion):
5061
"""Resamples the data.
@@ -64,8 +75,8 @@ def fit(self, data):
6475
* data -> (examples,attributes) format
6576
"""
6677
# zscore and clip
67-
data = zscore(data, axis=0)
68-
data = np.clip(data, a_min=-3, a_max=3)
78+
if self.z_score:
79+
data = self._z_score(data)
6980
Mk = np.zeros((data.shape[0], data.shape[0]))
7081
Is = np.zeros((data.shape[0],) * 2)
7182
for _ in range(self.H):
@@ -93,6 +104,11 @@ def fit(self, data):
93104

94105
def fit_predict(self, data):
    """Cluster ``data`` into ``self.n_clusters`` groups and return the labels.

    When ``self.z_score`` is truthy the data is first passed through
    ``self._z_score``; otherwise it is handed to the clustering estimator
    untouched.  The estimator is built from ``self.cluster`` with
    ``n_clusters=self.n_clusters`` and ``linkage=self.linkage``.

    :param data: observations to cluster, in (examples, attributes) layout
    :return: whatever the estimator's ``fit_predict`` returns for ``data``
    """
    # Normalisation is opt-in: applying it unconditionally would ignore both
    # the ``z_score`` switch and the configured ``z_cap``.
    if self.z_score:
        data = self._z_score(data)
    model = self.cluster(n_clusters=self.n_clusters, linkage=self.linkage)
    return model.fit_predict(data)
110+
111+
def _z_score(self, data):
    """Standardise every column of ``data`` and clip to ``[-z_cap, z_cap]``.

    Each column (axis 0) is z-scored with ``scipy.stats.zscore`` and the
    result is clipped to ``±self.z_cap``.  A column whose entries are all
    identical has zero standard deviation, for which the z-score is
    undefined (NaN); such columns are mapped to 0 so that they do not
    inject NaN into the downstream clustering.  NaN values that come from
    missing input are not masked: a column containing NaN is not treated
    as constant.

    :param data: array-like in (examples, attributes) layout
    :return: ``numpy.ndarray`` of the same shape as ``data``
    """
    values = np.asarray(data)
    scaled = zscore(values, axis=0)
    # Exact-equality test against the first row; NaN != NaN, so columns with
    # missing values are deliberately left to propagate.
    constant = np.all(values == values[:1], axis=0)
    scaled = np.where(constant, 0.0, scaled)
    return np.clip(scaled, a_min=-self.z_cap, a_max=self.z_cap)

0 commit comments

Comments
 (0)