Source code for simba.tools._pbg

"""PyTorch-BigGraph (PBG) for learning graph embeddings"""

import numpy as np
import pandas as pd
import os
import json

from pathlib import Path
import attr
from torchbiggraph.config import (
    add_to_sys_path,
    ConfigFileLoader
)
from torchbiggraph.converters.importers import (
    convert_input_data,
    TSVEdgelistReader
)
from torchbiggraph.train import train
from torchbiggraph.util import (
    set_logging_verbosity,
    setup_logging,
    SubprocessInitializer,
)

from .._settings import settings


def gen_graph(
        list_CP=None,
        list_PM=None,
        list_PK=None,
        list_CG=None,
        list_CC=None,
        list_adata=None,
        prefix_C='C',
        prefix_P='P',
        prefix_M='M',
        prefix_K='K',
        prefix_G='G',
        prefix='E',
        layer='simba',
        copy=False,
        dirname='graph0',
        add_edge_weights=None,
        use_highly_variable=True,
        use_top_pcs=True,
        use_top_pcs_CP=None,
        use_top_pcs_PM=None,
        use_top_pcs_PK=None):
    """Generate a graph for PBG training.

    Observations and variables of each AnnData object are encoded as
    nodes (entities). The non-zero values in `.layers['simba']`
    (by default) or `.X` (if `.layers['simba']` does not exist) indicate
    the edges between nodes. The values of `.layers['simba']` or `.X`
    are used as edge weights if `add_edge_weights` is True.

    When `list_adata` is specified, nodes shared between different
    AnnData objects in it are matched automatically based on
    `.obs_names` and `.var_names`. It is a generalized parameter that
    supersedes the data-specific parameters such as `list_CG`,
    `list_CP`, `list_PK`, etc. Each AnnData object contributes one
    relation type.

    An accompanying file 'entity_alias.txt' is also generated, mapping
    the original indices to the aliases used in the graph.

    Note that when `add_edge_weights` is True, `list_CG` generates only
    one cell-gene relation, as opposed to multiple relations based on
    discretized expression levels.

    Parameters
    ----------
    list_CP: `list`, optional (default: None)
        A list of AnnData objects that store ATAC-seq data
        (Cells by Peaks).
        The default weight of the cell-peak relation type is 1.0.
        Ignored when `list_adata` is specified.
    list_PM: `list`, optional (default: None)
        A list of AnnData objects that store the relation between
        Peaks and Motifs.
        Ignored when `list_adata` is specified.
    list_PK: `list`, optional (default: None)
        A list of AnnData objects that store the relation between
        Peaks and Kmers.
        Ignored when `list_adata` is specified.
    list_CG: `list`, optional (default: None)
        A list of AnnData objects that store RNA-seq data
        (Cells by Genes).
        Ignored when `list_adata` is specified.
    list_CC: `list`, optional (default: None)
        A list of AnnData objects that store the relation between
        Cells from two conditions.
        Ignored when `list_adata` is specified.
    list_adata: `list`, optional (default: None)
        A list of AnnData objects.
        `.obs_names` and `.var_names` between AnnData objects will be
        automatically matched.
        If `list_adata` is specified, the other lists including
        `list_CP`, `list_PM`, `list_PK`, `list_CG`, `list_CC`
        will be ignored.
    prefix_C: `str`, optional (default: 'C')
        Prefix to indicate the entity type of cells.
        Ignored when `list_adata` is specified.
    prefix_P: `str`, optional (default: 'P')
        Prefix to indicate the entity type of peaks.
        Ignored when `list_adata` is specified.
    prefix_M: `str`, optional (default: 'M')
        Prefix to indicate the entity type of motifs.
        Ignored when `list_adata` is specified.
    prefix_K: `str`, optional (default: 'K')
        Prefix to indicate the entity type of kmers.
        Ignored when `list_adata` is specified.
    prefix_G: `str`, optional (default: 'G')
        Prefix to indicate the entity type of genes.
        Ignored when `list_adata` is specified.
    prefix: `str`, optional (default: 'E')
        Prefix to indicate general entities in `list_adata`.
    layer: `str`, optional (default: 'simba')
        The layer in AnnData to use for constructing the graph.
        If `layer` is None or the specified layer does not exist,
        `.X` in AnnData is used instead.
    dirname: `str`, (default: 'graph0')
        The name of the directory in which the graph will be stored.
    add_edge_weights: `bool`, optional (default: None)
        If True, a column of edge weights is added.
        If `list_adata` is specified, `add_edge_weights` defaults to
        True; otherwise it defaults to False.
    use_highly_variable: `bool`, optional (default: True)
        Use highly variable genes.
        Only valid for `list_CG`.
        Ignored when `list_adata` is specified.
    use_top_pcs: `bool`, optional (default: True)
        Use top-PCs-associated features for CP, PM, and PK.
        Only valid for `list_CP`, `list_PM`, `list_PK`.
        Ignored when `list_adata` is specified.
    use_top_pcs_CP: `bool`, optional (default: None)
        Use top-PCs-associated features for CP.
        Only valid for `list_CP`.
        Once specified, it overrides `use_top_pcs`.
        Ignored when `list_adata` is specified.
    use_top_pcs_PM: `bool`, optional (default: None)
        Use top-PCs-associated features for PM.
        Only valid for `list_PM`.
        Once specified, it overrides `use_top_pcs`.
        Ignored when `list_adata` is specified.
    use_top_pcs_PK: `bool`, optional (default: None)
        Use top-PCs-associated features for PK.
        Only valid for `list_PK`.
        Once specified, it overrides `use_top_pcs`.
        Ignored when `list_adata` is specified.
    copy: `bool`, optional (default: False)
        If True, return the graph file as a data frame.

    Returns
    -------
    edges: `pd.DataFrame`
        Only returned if `copy` is True.
        The edges of the graph used for PBG training. Each row
        describes one edge: tab-separated identifiers of the source
        entity, the relation type, and the target entity.

    updates `settings.pbg_params` with the following parameters.
    entity_path: `str`
        The path of the directory containing entity count files.
    edge_paths: `list`
        A list of paths to directories containing (partitioned)
        edgelists. Typically a single path is provided.
    entities: `dict`
        The entity types.
    relations: `list`
        The relation types.

    updates `settings.graph_stats` with the following parameter.
    dirname: `dict`
        Statistics of the input graph.
    """
    if all(x is None for x in
           [list_CP, list_PM, list_PK, list_CG, list_CC, list_adata]):
        return 'No graph is generated'

    filepath = os.path.join(settings.workdir, 'pbg', dirname)
    settings.pbg_params['entity_path'] = \
        os.path.join(filepath, "input/entity")
    settings.pbg_params['edge_paths'] = \
        [os.path.join(filepath, "input/edge"), ]
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    if add_edge_weights is None:
        # edge weights are used by default only for the generic input
        add_edge_weights = list_adata is not None

    if list_adata is not None:
        id_ent = pd.Index([])  # ids of all entities
        dict_ent_type = dict()
        ctr_ent = 0  # counter for entity types
        entity_alias = pd.DataFrame(columns=['alias'])
        dict_graph_stats = dict()
        if add_edge_weights:
            col_names = ["source", "relation", "destination", "weight"]
        else:
            col_names = ["source", "relation", "destination"]
        df_edges = pd.DataFrame(columns=col_names)
        settings.pbg_params['relations'] = []
        for ctr_rel, adata_ori in enumerate(list_adata):
            obs_names = adata_ori.obs_names
            var_names = adata_ori.var_names
            if len(set(obs_names).intersection(id_ent)) == 0:
                # a new entity type for the observations
                prefix_i = f'{prefix}{ctr_ent}'
                id_ent = id_ent.union(adata_ori.obs_names)
                entity_alias_obs = pd.DataFrame(
                    index=obs_names,
                    columns=['alias'],
                    data=[f'{prefix_i}.{x}'
                          for x in range(len(obs_names))])
                settings.pbg_params['entities'][prefix_i] = \
                    {'num_partitions': 1}
                dict_ent_type[prefix_i] = obs_names
                entity_alias = pd.concat(
                    [entity_alias, entity_alias_obs],
                    ignore_index=False)
                obs_type = prefix_i
                ctr_ent += 1
            else:
                for k, item in dict_ent_type.items():
                    if len(set(obs_names).intersection(item)) > 0:
                        obs_type = k
                        break
                if not set(obs_names).issubset(id_ent):
                    id_ent = id_ent.union(adata_ori.obs_names)
                    adt_obs_names = list(set(obs_names) - set(item))
                    # alias new names under the matched entity type,
                    # continuing its numbering
                    entity_alias_obs = pd.DataFrame(
                        index=adt_obs_names,
                        columns=['alias'],
                        data=[f'{obs_type}.{len(item)+x}'
                              for x in range(len(adt_obs_names))])
                    dict_ent_type[obs_type] = item.union(adt_obs_names)
                    entity_alias = pd.concat(
                        [entity_alias, entity_alias_obs],
                        ignore_index=False)
            if len(set(var_names).intersection(id_ent)) == 0:
                # a new entity type for the variables
                prefix_i = f'{prefix}{ctr_ent}'
                id_ent = id_ent.union(adata_ori.var_names)
                entity_alias_var = pd.DataFrame(
                    index=var_names,
                    columns=['alias'],
                    data=[f'{prefix_i}.{x}'
                          for x in range(len(var_names))])
                settings.pbg_params['entities'][prefix_i] = \
                    {'num_partitions': 1}
                dict_ent_type[prefix_i] = var_names
                entity_alias = pd.concat(
                    [entity_alias, entity_alias_var],
                    ignore_index=False)
                var_type = prefix_i
                ctr_ent += 1
            else:
                for k, item in dict_ent_type.items():
                    if len(set(var_names).intersection(item)) > 0:
                        var_type = k
                        break
                if not set(var_names).issubset(id_ent):
                    id_ent = id_ent.union(adata_ori.var_names)
                    adt_var_names = list(set(var_names) - set(item))
                    entity_alias_var = pd.DataFrame(
                        index=adt_var_names,
                        columns=['alias'],
                        data=[f'{var_type}.{len(item)+x}'
                              for x in range(len(adt_var_names))])
                    dict_ent_type[var_type] = item.union(adt_var_names)
                    entity_alias = pd.concat(
                        [entity_alias, entity_alias_var],
                        ignore_index=False)

            # generate edges
            if layer is not None:
                if layer in adata_ori.layers.keys():
                    arr_simba = adata_ori.layers[layer]
                else:
                    print(f'`{layer}` does not exist in adata {ctr_rel} '
                          'in `list_adata`. `.X` is being used instead.')
                    arr_simba = adata_ori.X
            else:
                arr_simba = adata_ori.X
            _row, _col = arr_simba.nonzero()
            df_edges_x = pd.DataFrame(columns=col_names)
            df_edges_x['source'] = entity_alias.loc[
                obs_names[_row], 'alias'].values
            df_edges_x['relation'] = f'r{ctr_rel}'
            df_edges_x['destination'] = entity_alias.loc[
                var_names[_col], 'alias'].values
            if add_edge_weights:
                df_edges_x['weight'] = \
                    arr_simba[_row, _col].A.flatten()
            settings.pbg_params['relations'].append({
                'name': f'r{ctr_rel}',
                'lhs': f'{obs_type}',
                'rhs': f'{var_type}',
                'operator': 'none',
                'weight': 1.0
            })
            dict_graph_stats[f'relation{ctr_rel}'] = {
                'source': obs_type,
                'destination': var_type,
                'n_edges': df_edges_x.shape[0]}
            print(f'relation{ctr_rel}: '
                  f'source: {obs_type}, '
                  f'destination: {var_type}\n'
                  f'#edges: {df_edges_x.shape[0]}')
            df_edges = pd.concat(
                [df_edges, df_edges_x], ignore_index=True)

            adata_ori.obs['pbg_id'] = ""
            adata_ori.var['pbg_id'] = ""
            adata_ori.obs.loc[obs_names, 'pbg_id'] = \
                entity_alias.loc[obs_names, 'alias'].copy()
            adata_ori.var.loc[var_names, 'pbg_id'] = \
                entity_alias.loc[var_names, 'alias'].copy()
    else:
        # Collect the indices of entities
        dict_cells = dict()  # unique cell indices from all anndata objects
        ids_genes = pd.Index([])
        ids_peaks = pd.Index([])
        ids_kmers = pd.Index([])
        ids_motifs = pd.Index([])
        if list_CP is not None:
            for adata_ori in list_CP:
                if use_top_pcs_CP is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_CP
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                ids_cells_i = adata.obs.index
                if len(dict_cells) == 0:
                    dict_cells[prefix_C] = ids_cells_i
                else:
                    # check if cell indices are included in dict_cells
                    flag_included = False
                    for k in dict_cells.keys():
                        ids_cells_k = dict_cells[k]
                        if set(ids_cells_i) <= set(ids_cells_k):
                            flag_included = True
                            break
                    if not flag_included:
                        # create a new set of entities
                        # when not all indices are included
                        dict_cells[f'{prefix_C}{len(dict_cells)+1}'] = \
                            ids_cells_i
                ids_peaks = ids_peaks.union(adata.var.index)
        if list_PM is not None:
            for adata_ori in list_PM:
                if use_top_pcs_PM is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_PM
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                ids_peaks = ids_peaks.union(adata.obs.index)
                ids_motifs = ids_motifs.union(adata.var.index)
        if list_PK is not None:
            for adata_ori in list_PK:
                if use_top_pcs_PK is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_PK
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                ids_peaks = ids_peaks.union(adata.obs.index)
                ids_kmers = ids_kmers.union(adata.var.index)
        if list_CG is not None:
            for adata_ori in list_CG:
                if use_highly_variable:
                    adata = adata_ori[
                        :, adata_ori.var['highly_variable']].copy()
                else:
                    adata = adata_ori.copy()
                ids_cells_i = adata.obs.index
                if len(dict_cells) == 0:
                    dict_cells[prefix_C] = ids_cells_i
                else:
                    # check if cell indices are included in dict_cells
                    flag_included = False
                    for k in dict_cells.keys():
                        ids_cells_k = dict_cells[k]
                        if set(ids_cells_i) <= set(ids_cells_k):
                            flag_included = True
                            break
                    if not flag_included:
                        # create a new set of entities
                        # when not all indices are included
                        dict_cells[f'{prefix_C}{len(dict_cells)+1}'] = \
                            ids_cells_i
                ids_genes = ids_genes.union(adata.var.index)

        entity_alias = pd.DataFrame(columns=['alias'])
        dict_df_cells = dict()  # unique cell dataframes
        for k in dict_cells.keys():
            dict_df_cells[k] = pd.DataFrame(
                index=dict_cells[k],
                columns=['alias'],
                data=[f'{k}.{x}' for x in range(len(dict_cells[k]))])
            settings.pbg_params['entities'][k] = {'num_partitions': 1}
            entity_alias = pd.concat(
                [entity_alias, dict_df_cells[k]], ignore_index=False)
        if len(ids_genes) > 0:
            df_genes = pd.DataFrame(
                index=ids_genes,
                columns=['alias'],
                data=[f'{prefix_G}.{x}' for x in range(len(ids_genes))])
            settings.pbg_params['entities'][prefix_G] = \
                {'num_partitions': 1}
            entity_alias = pd.concat(
                [entity_alias, df_genes], ignore_index=False)
        if len(ids_peaks) > 0:
            df_peaks = pd.DataFrame(
                index=ids_peaks,
                columns=['alias'],
                data=[f'{prefix_P}.{x}' for x in range(len(ids_peaks))])
            settings.pbg_params['entities'][prefix_P] = \
                {'num_partitions': 1}
            entity_alias = pd.concat(
                [entity_alias, df_peaks], ignore_index=False)
        if len(ids_kmers) > 0:
            df_kmers = pd.DataFrame(
                index=ids_kmers,
                columns=['alias'],
                data=[f'{prefix_K}.{x}' for x in range(len(ids_kmers))])
            settings.pbg_params['entities'][prefix_K] = \
                {'num_partitions': 1}
            entity_alias = pd.concat(
                [entity_alias, df_kmers], ignore_index=False)
        if len(ids_motifs) > 0:
            df_motifs = pd.DataFrame(
                index=ids_motifs,
                columns=['alias'],
                data=[f'{prefix_M}.{x}' for x in range(len(ids_motifs))])
            settings.pbg_params['entities'][prefix_M] = \
                {'num_partitions': 1}
            entity_alias = pd.concat(
                [entity_alias, df_motifs], ignore_index=False)

        # generate edges
        dict_graph_stats = dict()
        if add_edge_weights:
            col_names = ["source", "relation", "destination", "weight"]
        else:
            col_names = ["source", "relation", "destination"]
        df_edges = pd.DataFrame(columns=col_names)
        id_r = 0
        settings.pbg_params['relations'] = []
        if list_CP is not None:
            for i, adata_ori in enumerate(list_CP):
                # honor `use_top_pcs_CP` so the feature subset matches
                # the one used during entity collection above
                if use_top_pcs_CP is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_CP
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                # select reference of cells
                for key, df_cells in dict_df_cells.items():
                    if set(adata.obs_names) <= set(df_cells.index):
                        break
                if layer is not None:
                    if layer in adata.layers.keys():
                        arr_simba = adata.layers[layer]
                    else:
                        print(f'`{layer}` does not exist in anndata {i} '
                              'in `list_CP`. `.X` is being used instead.')
                        arr_simba = adata.X
                else:
                    arr_simba = adata.X
                _row, _col = arr_simba.nonzero()
                df_edges_x = pd.DataFrame(columns=col_names)
                df_edges_x['source'] = df_cells.loc[
                    adata.obs_names[_row], 'alias'].values
                df_edges_x['relation'] = f'r{id_r}'
                df_edges_x['destination'] = df_peaks.loc[
                    adata.var_names[_col], 'alias'].values
                if add_edge_weights:
                    df_edges_x['weight'] = \
                        arr_simba[_row, _col].A.flatten()
                # cell-peak relations use weight 1.0 either way
                settings.pbg_params['relations'].append({
                    'name': f'r{id_r}',
                    'lhs': f'{key}',
                    'rhs': f'{prefix_P}',
                    'operator': 'none',
                    'weight': 1.0
                })
                dict_graph_stats[f'relation{id_r}'] = {
                    'source': key,
                    'destination': prefix_P,
                    'n_edges': df_edges_x.shape[0]}
                print(f'relation{id_r}: '
                      f'source: {key}, '
                      f'destination: {prefix_P}\n'
                      f'#edges: {df_edges_x.shape[0]}')
                id_r += 1
                df_edges = pd.concat(
                    [df_edges, df_edges_x], ignore_index=True)

                adata_ori.obs['pbg_id'] = ""
                adata_ori.var['pbg_id'] = ""
                adata_ori.obs.loc[adata.obs_names, 'pbg_id'] = \
                    df_cells.loc[adata.obs_names, 'alias'].copy()
                adata_ori.var.loc[adata.var_names, 'pbg_id'] = \
                    df_peaks.loc[adata.var_names, 'alias'].copy()
        if list_PM is not None:
            for i, adata_ori in enumerate(list_PM):
                if use_top_pcs_PM is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_PM
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                if layer is not None:
                    if layer in adata.layers.keys():
                        arr_simba = adata.layers[layer]
                    else:
                        print(f'`{layer}` does not exist in anndata {i} '
                              'in `list_PM`. `.X` is being used instead.')
                        arr_simba = adata.X
                else:
                    arr_simba = adata.X
                _row, _col = arr_simba.nonzero()
                df_edges_x = pd.DataFrame(columns=col_names)
                df_edges_x['source'] = df_peaks.loc[
                    adata.obs_names[_row], 'alias'].values
                df_edges_x['relation'] = f'r{id_r}'
                df_edges_x['destination'] = df_motifs.loc[
                    adata.var_names[_col], 'alias'].values
                if add_edge_weights:
                    df_edges_x['weight'] = \
                        arr_simba[_row, _col].A.flatten()
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{prefix_P}',
                        'rhs': f'{prefix_M}',
                        'operator': 'none',
                        'weight': 1.0
                    })
                else:
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{prefix_P}',
                        'rhs': f'{prefix_M}',
                        'operator': 'none',
                        'weight': 0.2
                    })
                dict_graph_stats[f'relation{id_r}'] = {
                    'source': prefix_P,
                    'destination': prefix_M,
                    'n_edges': df_edges_x.shape[0]}
                print(f'relation{id_r}: '
                      f'source: {prefix_P}, '
                      f'destination: {prefix_M}\n'
                      f'#edges: {df_edges_x.shape[0]}')
                id_r += 1
                df_edges = pd.concat(
                    [df_edges, df_edges_x], ignore_index=True)

                adata_ori.obs['pbg_id'] = ""
                adata_ori.var['pbg_id'] = ""
                adata_ori.obs.loc[adata.obs_names, 'pbg_id'] = \
                    df_peaks.loc[adata.obs_names, 'alias'].copy()
                adata_ori.var.loc[adata.var_names, 'pbg_id'] = \
                    df_motifs.loc[adata.var_names, 'alias'].copy()
        if list_PK is not None:
            for i, adata_ori in enumerate(list_PK):
                if use_top_pcs_PK is None:
                    flag_top_pcs = use_top_pcs
                else:
                    flag_top_pcs = use_top_pcs_PK
                if flag_top_pcs:
                    adata = adata_ori[:, adata_ori.var['top_pcs']].copy()
                else:
                    adata = adata_ori.copy()
                if layer is not None:
                    if layer in adata.layers.keys():
                        arr_simba = adata.layers[layer]
                    else:
                        print(f'`{layer}` does not exist in anndata {i} '
                              'in `list_PK`. `.X` is being used instead.')
                        arr_simba = adata.X
                else:
                    arr_simba = adata.X
                _row, _col = arr_simba.nonzero()
                df_edges_x = pd.DataFrame(columns=col_names)
                df_edges_x['source'] = df_peaks.loc[
                    adata.obs_names[_row], 'alias'].values
                df_edges_x['relation'] = f'r{id_r}'
                df_edges_x['destination'] = df_kmers.loc[
                    adata.var_names[_col], 'alias'].values
                if add_edge_weights:
                    df_edges_x['weight'] = \
                        arr_simba[_row, _col].A.flatten()
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{prefix_P}',
                        'rhs': f'{prefix_K}',
                        'operator': 'none',
                        'weight': 1.0
                    })
                else:
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{prefix_P}',
                        'rhs': f'{prefix_K}',
                        'operator': 'none',
                        'weight': 0.02
                    })
                print(f'relation{id_r}: '
                      f'source: {prefix_P}, '
                      f'destination: {prefix_K}\n'
                      f'#edges: {df_edges_x.shape[0]}')
                dict_graph_stats[f'relation{id_r}'] = {
                    'source': prefix_P,
                    'destination': prefix_K,
                    'n_edges': df_edges_x.shape[0]}
                id_r += 1
                df_edges = pd.concat(
                    [df_edges, df_edges_x], ignore_index=True)

                adata_ori.obs['pbg_id'] = ""
                adata_ori.var['pbg_id'] = ""
                adata_ori.obs.loc[adata.obs_names, 'pbg_id'] = \
                    df_peaks.loc[adata.obs_names, 'alias'].copy()
                adata_ori.var.loc[adata.var_names, 'pbg_id'] = \
                    df_kmers.loc[adata.var_names, 'alias'].copy()
        if list_CG is not None:
            for i, adata_ori in enumerate(list_CG):
                if use_highly_variable:
                    adata = adata_ori[
                        :, adata_ori.var['highly_variable']].copy()
                else:
                    adata = adata_ori.copy()
                # select reference of cells
                for key, df_cells in dict_df_cells.items():
                    if set(adata.obs_names) <= set(df_cells.index):
                        break
                if layer is not None:
                    if layer in adata.layers.keys():
                        arr_simba = adata.layers[layer]
                    else:
                        print(f'`{layer}` does not exist in anndata {i} '
                              'in `list_CG`. `.X` is being used instead.')
                        arr_simba = adata.X
                else:
                    arr_simba = adata.X
                if add_edge_weights:
                    # a single cell-gene relation with continuous weights
                    _row, _col = arr_simba.nonzero()
                    df_edges_x = pd.DataFrame(columns=col_names)
                    df_edges_x['source'] = df_cells.loc[
                        adata.obs_names[_row], 'alias'].values
                    df_edges_x['relation'] = f'r{id_r}'
                    df_edges_x['destination'] = df_genes.loc[
                        adata.var_names[_col], 'alias'].values
                    df_edges_x['weight'] = \
                        arr_simba[_row, _col].A.flatten()
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{key}',
                        'rhs': f'{prefix_G}',
                        'operator': 'none',
                        'weight': 1.0,
                    })
                    print(f'relation{id_r}: '
                          f'source: {key}, '
                          f'destination: {prefix_G}\n'
                          f'#edges: {df_edges_x.shape[0]}')
                    dict_graph_stats[f'relation{id_r}'] = {
                        'source': key,
                        'destination': prefix_G,
                        'n_edges': df_edges_x.shape[0]}
                    id_r += 1
                    df_edges = pd.concat(
                        [df_edges, df_edges_x], ignore_index=True)
                else:
                    # one relation per discretized expression level,
                    # with relation weights spread linearly from 1 to 5
                    expr_level = np.unique(arr_simba.data)
                    expr_weight = np.linspace(
                        start=1, stop=5, num=len(expr_level))
                    for i_lvl, lvl in enumerate(expr_level):
                        _row, _col = \
                            (arr_simba == lvl).astype(int).nonzero()
                        df_edges_x = pd.DataFrame(columns=col_names)
                        df_edges_x['source'] = df_cells.loc[
                            adata.obs_names[_row], 'alias'].values
                        df_edges_x['relation'] = f'r{id_r}'
                        df_edges_x['destination'] = df_genes.loc[
                            adata.var_names[_col], 'alias'].values
                        settings.pbg_params['relations'].append({
                            'name': f'r{id_r}',
                            'lhs': f'{key}',
                            'rhs': f'{prefix_G}',
                            'operator': 'none',
                            'weight': round(expr_weight[i_lvl], 2),
                        })
                        print(f'relation{id_r}: '
                              f'source: {key}, '
                              f'destination: {prefix_G}\n'
                              f'#edges: {df_edges_x.shape[0]}')
                        dict_graph_stats[f'relation{id_r}'] = {
                            'source': key,
                            'destination': prefix_G,
                            'n_edges': df_edges_x.shape[0]}
                        id_r += 1
                        df_edges = pd.concat(
                            [df_edges, df_edges_x], ignore_index=True)

                adata_ori.obs['pbg_id'] = ""
                adata_ori.var['pbg_id'] = ""
                adata_ori.obs.loc[adata.obs_names, 'pbg_id'] = \
                    df_cells.loc[adata.obs_names, 'alias'].copy()
                adata_ori.var.loc[adata.var_names, 'pbg_id'] = \
                    df_genes.loc[adata.var_names, 'alias'].copy()
        if list_CC is not None:
            for i, adata in enumerate(list_CC):
                # select reference of cells
                for key_obs, df_cells_obs in dict_df_cells.items():
                    if set(adata.obs_names) <= set(df_cells_obs.index):
                        break
                for key_var, df_cells_var in dict_df_cells.items():
                    if set(adata.var_names) <= set(df_cells_var.index):
                        break
                if layer is not None:
                    if layer in adata.layers.keys():
                        arr_simba = adata.layers[layer]
                    else:
                        print(f'`{layer}` does not exist in anndata {i} '
                              'in `list_CC`. `.X` is being used instead.')
                        arr_simba = adata.X
                else:
                    arr_simba = adata.X
                _row, _col = arr_simba.nonzero()
                # edges between ref and query
                df_edges_x = pd.DataFrame(columns=col_names)
                df_edges_x['source'] = df_cells_obs.loc[
                    adata.obs_names[_row], 'alias'].values
                df_edges_x['relation'] = f'r{id_r}'
                df_edges_x['destination'] = df_cells_var.loc[
                    adata.var_names[_col], 'alias'].values
                if add_edge_weights:
                    df_edges_x['weight'] = \
                        arr_simba[_row, _col].A.flatten()
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{key_obs}',
                        'rhs': f'{key_var}',
                        'operator': 'none',
                        'weight': 1.0
                    })
                else:
                    settings.pbg_params['relations'].append({
                        'name': f'r{id_r}',
                        'lhs': f'{key_obs}',
                        'rhs': f'{key_var}',
                        'operator': 'none',
                        'weight': 10.0
                    })
                print(f'relation{id_r}: '
                      f'source: {key_obs}, '
                      f'destination: {key_var}\n'
                      f'#edges: {df_edges_x.shape[0]}')
                dict_graph_stats[f'relation{id_r}'] = {
                    'source': key_obs,
                    'destination': key_var,
                    'n_edges': df_edges_x.shape[0]}
                id_r += 1
                df_edges = pd.concat(
                    [df_edges, df_edges_x], ignore_index=True)

                adata.obs['pbg_id'] = df_cells_obs.loc[
                    adata.obs_names, 'alias'].copy()
                adata.var['pbg_id'] = df_cells_var.loc[
                    adata.var_names, 'alias'].copy()

    print(f'Total number of edges: {df_edges.shape[0]}')
    dict_graph_stats['n_edges'] = df_edges.shape[0]
    settings.graph_stats[dirname] = dict_graph_stats
    print(f'Writing graph file "pbg_graph.txt" to "{filepath}" ...')
    df_edges.to_csv(os.path.join(filepath, "pbg_graph.txt"),
                    header=False,
                    index=False,
                    sep='\t')
    entity_alias.to_csv(os.path.join(filepath, 'entity_alias.txt'),
                        header=True,
                        index=True,
                        sep='\t')
    with open(os.path.join(filepath, 'graph_stats.json'), 'w') as fp:
        json.dump(dict_graph_stats, fp,
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
    print("Finished.")
    if copy:
        return df_edges
    else:
        return None
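
A minimal usage sketch for `gen_graph`, assuming a preprocessed scRNA-seq AnnData object `adata_CG` whose `.layers['simba']` and `.var['highly_variable']` have already been computed upstream (the working-directory name is a placeholder):

import simba as si

# graphs are written under <workdir>/pbg/<dirname>
si.settings.set_workdir('result_simba')

# one cells-by-genes AnnData; with add_edge_weights left at its default
# (False here), each discretized expression level in .layers['simba']
# becomes its own relation type r0, r1, ...
si.tl.gen_graph(list_CG=[adata_CG],
                use_highly_variable=True,
                dirname='graph0')

The resulting `pbg_graph.txt` is a tab-separated edgelist (source, relation, destination, plus a weight column when `add_edge_weights` is True), which is exactly the layout that `TSVEdgelistReader` consumes in `pbg_train` below.
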
def pbg_train(dirname=None,
              pbg_params=None,
              output='model',
              auto_wd=True,
              save_wd=False,
              use_edge_weights=False):
    """PBG training.

    Parameters
    ----------
    dirname: `str`, optional (default: None)
        The name of the directory in which the graph is stored.
        If None, it is inferred from `pbg_params['entity_path']`.
    pbg_params: `dict`, optional (default: None)
        Configuration for PBG training.
        If specified, it is used instead of the default setting.
    output: `str`, optional (default: 'model')
        The name of the directory to which training output is written.
        It overrides `checkpoint_path` if that is specified in
        `pbg_params`.
    auto_wd: `bool`, optional (default: True)
        If True, `pbg_params['wd']` is overridden with a weight decay
        estimated from the training sample size.
        Recommended for relatively small training sample sizes (<1e7).
    save_wd: `bool`, optional (default: False)
        If True, the estimated `wd` is saved to
        `settings.pbg_params['wd']`.
    use_edge_weights: `bool`, optional (default: False)
        If True, edge weights are used for training;
        if False, the weights of relation types are used instead and
        edge weights are ignored.

    Returns
    -------
    updates `settings.pbg_params` with the following parameter.
    checkpoint_path:
        The path to the directory where checkpoints (and thus the
        output) will be written. If checkpoints are found in it,
        training will resume from them.
    """
    if pbg_params is None:
        pbg_params = settings.pbg_params.copy()
    else:
        assert isinstance(pbg_params, dict), \
            "`pbg_params` must be dict"
    if dirname is None:
        filepath = Path(
            pbg_params['entity_path']).parent.parent.as_posix()
    else:
        filepath = os.path.join(settings.workdir, 'pbg', dirname)

    pbg_params['checkpoint_path'] = os.path.join(filepath, output)
    settings.pbg_params['checkpoint_path'] = \
        pbg_params['checkpoint_path']

    if auto_wd:
        print('Auto-estimating weight decay ...')
        # empirical numbers from simulation experiments
        n_edges = settings.graph_stats[
            os.path.basename(filepath)]['n_edges']
        if n_edges < 5e7:
            # optimal wd (0.013) for sample size (2725781)
            wd = np.around(0.013 * 2725781 / n_edges, decimals=6)
        else:
            # optimal wd (0.0004) for sample size (59103481)
            wd = np.around(0.0004 * 59103481 / n_edges, decimals=6)
        pbg_params['wd'] = wd
        if save_wd:
            settings.pbg_params['wd'] = pbg_params['wd']
            print(f"`.settings.pbg_params['wd']` has been updated "
                  f"to {wd}")
    print(f'Weight decay being used for training is '
          f'{pbg_params["wd"]}')

    # to avoid oversubscription issues in workloads
    # that involve nested parallelism
    os.environ["OMP_NUM_THREADS"] = "1"

    loader = ConfigFileLoader()
    config = loader.load_config_simba(pbg_params)
    set_logging_verbosity(config.verbose)

    list_filenames = [os.path.join(filepath, "pbg_graph.txt")]
    input_edge_paths = [Path(name) for name in list_filenames]

    print("Converting input data ...")
    if use_edge_weights:
        print("Edge weights are being used ...")
        convert_input_data(
            config.entities,
            config.relations,
            config.entity_path,
            config.edge_paths,
            input_edge_paths,
            TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1,
                              weight_col=3),
            dynamic_relations=config.dynamic_relations,
        )
    else:
        convert_input_data(
            config.entities,
            config.relations,
            config.entity_path,
            config.edge_paths,
            input_edge_paths,
            TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1),
            dynamic_relations=config.dynamic_relations,
        )

    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    train_config = attr.evolve(config, edge_paths=config.edge_paths)

    print("Starting training ...")
    train(train_config, subprocess_init=subprocess_init)
    print("Finished")
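
A matching training call, as a sketch: after `gen_graph` has populated `settings.pbg_params`, `dirname` can be omitted and is inferred from `entity_path` (the `output` name is arbitrary):

# auto_wd rescales the weight decay by the edge count, as above;
# checkpoints land in <workdir>/pbg/<dirname>/model
si.tl.pbg_train(auto_wd=True, save_wd=True, output='model')
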