# Source code for drcme.bin.run_refine_unstable_coclusters

"""
Script to merge cells from unstable clusters into the most similar stable ones.

The script identifies unstable clusters by their mean Jaccard coefficients falling below a
specified threshold (``unstable_threshold``). It then determines how similar each unstable
cluster is to the stable clusters. If it is too dissimilar, it is kept as its own cluster.
Otherwise, the unstable cluster is dissolved and its cells are assigned to their
best-matching stable clusters.

It determines whether to dissolve a cluster by checking whether each cell in the
unstable cluster has a good match to a stable cluster (i.e., the cell's average co-clustering
rate with that cluster's cells exceeds ``coclust_threshold``). If enough of the cells of the
unstable cluster have good matches (the fraction of matching cells is at least
``pct_needed``), the cluster is dissolved and the cells of that cluster are reassigned to
stable clusters. Unstable clusters are dissolved one at a time, starting with the one whose
fraction of matching cells is highest.

.. autoclass:: RefineParameters

.. autofunction:: stable_match_rates
.. autofunction:: match_rates_for_unstable_clusters
.. autofunction:: new_labels_for_dissolved_cluster

"""

import numpy as np
import pandas as pd
import argschema as ags
import logging


class RefineParameters(ags.ArgSchema):
    """Parameter schema for unstable cluster refinement"""

    # --- Input files ---
    cocluster_matrix_file = ags.fields.InputFile(
        description="File path for co-clustering matrix")
    jaccards_file = ags.fields.InputFile(
        description="File path for Jaccard coefficients")
    cluster_labels_file = ags.fields.InputFile(
        description="File path for cluster labels")

    # --- Output files ---
    refined_labels_file = ags.fields.OutputFile(
        description="Output file path for refined cluster labels (as integers)")
    refined_text_labels_file = ags.fields.OutputFile(
        description="Output file path for refined cluster labels (as strings with me_prefix)")
    # main() writes the sample indices sorted by refined label (ties broken by
    # original label) here, not the labels themselves.
    refined_ordering_file = ags.fields.OutputFile(
        description="Output file path for sample ordering sorted by refined cluster label")

    # --- Refinement thresholds ---
    # A cluster whose mean Jaccard coefficient falls below this is unstable.
    unstable_threshold = ags.fields.Float(
        description="Threshold for Jaccard coefficients to determine stability",
        default=0.5)
    # A cell "matches" a stable cluster when its average co-clustering rate
    # with that cluster's cells is strictly greater than this value.
    coclust_threshold = ags.fields.Float(
        description="Threshold for co-clustering rate to be considered a match to another cluster",
        default=0.4)
    # An unstable cluster is dissolved when at least this fraction of its
    # cells match some stable cluster.
    pct_needed = ags.fields.Float(
        description="Minimum fraction of matching cells to dissolve a cluster",
        default=0.33)

    # No default is declared here, but main() takes me_prefix as a required
    # argument, so it must be supplied.
    me_prefix = ags.fields.String(
        description="prefix for refined cluster text labels")
def stable_match_rates(clust_labels, shared, stable_clusters):
    """Compute each stable-cluster cell's mean co-clustering rate with its peers.

    For every cell carrying one of the ``stable_clusters`` labels, average its
    co-clustering rates with all *other* cells that carry the same label.

    Parameters
    ----------
    clust_labels : (n, ) array
        Cluster labels for the `n` samples
    shared : (n, n) array
        Co-clustering rates between all `n` samples
    stable_clusters : list
        List of labels of the stable clusters

    Returns
    -------
    list
        One average rate per cell. Cells are grouped by cluster in the order of
        ``stable_clusters`` and, within a cluster, appear in increasing sample
        index. A cell that is the sole member of its cluster has no peers to
        average over, so its entry is NaN (``np.mean`` of an empty selection).
    """
    positions = np.arange(shared.shape[0])
    rates = []
    for label in stable_clusters:
        same_label = label == clust_labels
        for idx in np.flatnonzero(same_label):
            # Peers: same cluster, excluding the cell itself
            peers = same_label & (positions != idx)
            rates.append(np.mean(shared[idx, :][peers]))
    return rates
def match_rates_for_unstable_clusters(unstable_clusters, stable_clusters, clust_labels, shared, threshold):
    """Report, per unstable cluster, the fraction of its cells that match a stable cluster.

    A cell's match score is the highest of its average co-clustering rates
    with the members of each stable cluster. The cell counts as matching when
    that score is strictly greater than ``threshold``. Unstable clusters with
    no cells are left out of the result.

    Parameters
    ----------
    unstable_clusters : list
        List of labels of the unstable clusters
    stable_clusters : list
        List of labels of the stable clusters
    clust_labels : (n, ) array
        Cluster labels for the `n` samples
    shared : (n, n) array
        Co-clustering rates between all `n` samples
    threshold : float
        Minimum co-clustering rate to be considered a match with another cluster

    Returns
    -------
    dict
        Maps each non-empty unstable cluster label to its fraction of matching cells
    """
    # Stable-cluster membership does not depend on which cell is being scored,
    # so the masks are built once instead of once per cell.
    stable_masks = [other_cl == clust_labels for other_cl in stable_clusters]

    fractions = {}
    for cl in unstable_clusters:
        members = np.flatnonzero(cl == clust_labels)
        if members.size == 0:
            continue
        best = np.array([
            np.max([np.mean(shared[idx, :][mask]) for mask in stable_masks])
            for idx in members
        ])
        fractions[cl] = np.sum(best > threshold).astype(float) / len(best)
    return fractions
def new_labels_for_dissolved_cluster(cl, clust_labels, shared, stable_clusters):
    """Reassign every cell of cluster ``cl`` to its best-matching stable cluster.

    Each cell of ``cl`` gets the stable-cluster label whose members it has the
    highest average co-clustering rate with. Ties go to the earliest entry of
    ``stable_clusters``. The input labels are not modified.

    Parameters
    ----------
    cl : int
        Cluster that will be dissolved
    clust_labels : (n, ) array
        Cluster labels for the `n` samples
    shared : (n, n) array
        Co-clustering rates between all `n` samples
    stable_clusters : list
        List of labels of the stable clusters

    Returns
    -------
    (n, ) array
        Copy of ``clust_labels`` with the cells of ``cl`` relabelled
    """
    # Membership is always taken from the *input* labels. Writing into the copy
    # inside the loop therefore cannot influence later cells' assignments.
    stable_masks = [other_cl == clust_labels for other_cl in stable_clusters]

    new_labels = clust_labels.copy()
    for idx in np.flatnonzero(cl == clust_labels):
        mean_rates = [np.mean(shared[idx, :][mask]) for mask in stable_masks]
        new_labels[idx] = stable_clusters[np.argmax(mean_rates)]
    return new_labels
def _dissolve_unstable_clusters(clust_labels, shared, unstable_clusters,
                                stable_clusters, coclust_threshold, pct_needed):
    """Repeatedly dissolve the best-matching unstable cluster; return refined labels."""
    refined = clust_labels.copy()
    while True:
        mr = match_rates_for_unstable_clusters(unstable_clusters, stable_clusters,
                                               refined, shared, threshold=coclust_threshold)
        # A dissolved cluster has no cells left, so it drops out of `mr`.
        # The loop therefore ends at the latest once every unstable cluster
        # has been dissolved.
        if not mr:
            break
        cl_for_dissolve = max(mr, key=mr.get)
        if mr[cl_for_dissolve] < pct_needed:
            break
        logging.info("Dissolving {:d}".format(int(cl_for_dissolve)))
        refined = new_labels_for_dissolved_cluster(cl_for_dissolve, refined,
                                                   shared, stable_clusters)
    return refined


def main(cocluster_matrix_file, jaccards_file, cluster_labels_file,
         refined_labels_file, refined_text_labels_file, refined_ordering_file,
         unstable_threshold, coclust_threshold, pct_needed, me_prefix, **kwargs):
    """ Main runner function for script.

    See argschema input parameters for argument descriptions.

    The function performs these steps:

    1. Load the co-clustering matrix, the Jaccard coefficients and the cluster
       labels.
    2. Split clusters into stable and unstable by their mean Jaccard
       coefficient.
    3. Repeatedly dissolve the unstable cluster with the highest fraction of
       well-matched cells, as long as that fraction is at least
       ``pct_needed``.
    4. Write the refined integer labels, the text labels (``me_prefix``
       followed by the 1-based rank of the label) and the sample ordering.

    If there are no stable clusters, nothing can absorb an unstable cluster.
    In that case the labels are written unchanged.
    """
    shared = np.loadtxt(cocluster_matrix_file)
    jaccards = np.loadtxt(jaccards_file)
    clust_labels_df = pd.read_csv(cluster_labels_file, index_col=0)
    logging.debug("Cluster labels (head):\n%s", clust_labels_df.head())
    clust_labels = clust_labels_df["0"].values

    # NOTE(review): row i of the Jaccard matrix is assumed to describe cluster
    # label i in clust_labels. Confirm against the upstream clustering step.
    mean_jaccards = jaccards.mean(axis=1)
    unstable_clusters = np.flatnonzero(mean_jaccards < unstable_threshold)
    stable_clusters = np.flatnonzero(mean_jaccards >= unstable_threshold)
    logging.info("Found {:d} stable clusters".format(len(stable_clusters)))
    logging.info("Found {:d} unstable clusters".format(len(unstable_clusters)))
    logging.info("Using coclustering threshold of {:g}".format(coclust_threshold))

    if len(stable_clusters) == 0:
        # Nothing to merge into. Previously this crashed on an empty percentile.
        logging.warning("No stable clusters found; leaving cluster labels unchanged")
        refined_clust_labels = clust_labels.copy()
    else:
        stable_rates = stable_match_rates(clust_labels, shared, stable_clusters)
        if stable_rates:
            # Singleton stable clusters yield NaN rates (no peers to average),
            # so ignore NaNs in this diagnostic summary.
            logging.info("cf. stable rates 5th percentile {:g}".format(
                np.nanpercentile(stable_rates, 5)))
        refined_clust_labels = _dissolve_unstable_clusters(
            clust_labels, shared, unstable_clusters, stable_clusters,
            coclust_threshold, pct_needed)

    # Order samples by refined label, breaking ties by original label
    refined_order = np.lexsort((clust_labels, refined_clust_labels))
    # np.unique returns sorted labels; map them onto 1..k for the text labels
    refined_relabel = {v: i + 1 for i, v in enumerate(np.unique(refined_clust_labels))}
    refined_text_labels = ["{:s}_{:d}".format(me_prefix, refined_relabel[v])
                           for v in refined_clust_labels]

    sample_index = clust_labels_df.index.values
    pd.DataFrame(refined_clust_labels, index=sample_index).to_csv(refined_labels_file)
    pd.DataFrame(refined_text_labels, index=sample_index).to_csv(refined_text_labels_file)
    np.savetxt(refined_ordering_file, refined_order, fmt="%d")


if __name__ == "__main__":
    module = ags.ArgSchemaParser(schema_type=RefineParameters)
    main(**module.args)