1212from itertools import combinations
1313
1414import numpy as np
15+ from scipy .stats import zscore
1516from sklearn .cluster import AgglomerativeClustering
1617
1718from . import BaseClusterEstimator
@@ -34,7 +35,15 @@ class ConsensusCluster(BaseClusterEstimator):
3435 """
3536
3637 def __init__ (
37- self , n_clusters , K = None , H = 100 , resample_proportion = 0.9 , linkage = "average" , cluster = AgglomerativeClustering
38+ self ,
39+ n_clusters ,
40+ K = None ,
41+ H = 100 ,
42+ resample_proportion = 0.9 ,
43+ linkage = "average" ,
44+ z_score = False ,
45+ z_cap = 3 , # ignored if z_score is False
46+ cluster = AgglomerativeClustering ,
3847 ):
3948 super ().__init__ ()
4049 assert 0 <= resample_proportion <= 1 , "proportion has to be between 0 and 1"
@@ -44,6 +53,9 @@ def __init__(
4453 self .resample_proportion = resample_proportion
4554 self .cluster = cluster
4655 self .linkage = linkage
56+ self .z_score = z_score
57+ assert z_cap > 0 , f"z_cap should be stricly positive, but got { z_cap } "
58+ self .z_cap = z_cap
4759
4860 def _internal_resample (self , data , proportion ):
4961 """Resamples the data.
@@ -62,6 +74,9 @@ def fit(self, data):
6274 Args:
6375 * data -> (examples,attributes) format
6476 """
77+ # zscore and clip
78+ if self .z_score :
79+ data = self ._z_score (data )
6580 Mk = np .zeros ((data .shape [0 ], data .shape [0 ]))
6681 Is = np .zeros ((data .shape [0 ],) * 2 )
6782 for _ in range (self .H ):
@@ -89,4 +104,11 @@ def fit(self, data):
89104
def fit_predict(self, data):
    """Cluster ``data`` with the configured estimator and return its labels.

    Args:
        * data -> (examples,attributes) format

    Returns:
        Whatever ``fit_predict`` of the estimator built from ``self.cluster``
        returns for the (optionally standardised) data.
    """
    # Standardise and cap the features only when the option is switched on.
    samples = self._z_score(data) if self.z_score else data
    estimator = self.cluster(n_clusters=self.n_clusters, linkage=self.linkage)
    return estimator.fit_predict(samples)
110+
def _z_score(self, data):
    """Standardise each column of ``data`` and clip it to ``±self.z_cap``.

    Columns are z-scored along axis 0 with ``scipy.stats.zscore``. A column
    whose values are all identical has zero variance, for which ``zscore``
    yields NaN; such columns are mapped to 0.0 instead so that the result
    can be handed to a clustering estimator. NaNs already present in the
    input are not masked.

    Args:
        * data -> (examples,attributes) format

    Returns:
        Array of the same shape as ``data`` with every value inside
        ``[-self.z_cap, self.z_cap]`` (except propagated input NaNs).
    """
    values = np.asarray(data)
    # A zero standard deviation triggers a divide warning; the affected
    # columns are patched just below, so the warning carries no information.
    with np.errstate(divide="ignore", invalid="ignore"):
        scored = zscore(values, axis=0)
    if values.size:
        # min == max is an exact test for a constant column; comparing the
        # standard deviation with 0 can miss it because of rounding in the mean.
        constant = values.max(axis=0) == values.min(axis=0)
        scored = np.where(constant, 0.0, scored)
    return np.clip(scored, -self.z_cap, self.z_cap)
0 commit comments