
This function trains an XGBoost classifier to distinguish cells in a specified case group from all other cells, using gene expression data from a Seurat object. SHAP (SHapley Additive exPlanations) values are then computed to quantify each gene's contribution to the model's predictions, and the top SHAP-ranked genes are returned as candidate key drivers.
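To make "each gene's contribution" concrete, the base-R sketch below computes exact SHAP values for a toy linear model with independent features, where the SHAP value of gene j in a cell is w_j * (x_j - mean(x_j)). This is not what FindShapKeyDriver runs: XGBoost explains its tree ensembles with TreeSHAP. The data, gene names, weights, and intercept are invented for illustration. The sketch only shows the additive decomposition (baseline plus per-gene contributions equals the model output) and a mean-absolute-SHAP ranking of the kind summarised in shap_summary.

```r
# Toy data: 100 "cells" x 3 "genes" (invented values, not real expression)
set.seed(1)
X <- matrix(rnorm(300), nrow = 100, ncol = 3,
            dimnames = list(NULL, c("GeneA", "GeneB", "GeneC")))
w <- c(GeneA = 2.0, GeneB = -0.5, GeneC = 0.0)  # hypothetical weights
b <- 0.3                                         # hypothetical intercept

# Model output, e.g. on the log-odds scale
f <- function(X) as.vector(X %*% w + b)

# Exact SHAP values for a linear model with independent features:
# phi[i, j] = w[j] * (X[i, j] - mean(X[, j]))
mu  <- colMeans(X)
phi <- sweep(sweep(X, 2, mu), 2, w, "*")

# Additivity: baseline + sum of per-gene contributions = prediction
baseline <- mean(f(X))
stopifnot(isTRUE(all.equal(baseline + rowSums(phi), f(X))))

# Global importance: mean absolute SHAP value per gene, largest first
importance <- sort(colMeans(abs(phi)), decreasing = TRUE)
print(importance)
```

GeneC has a zero weight, so its SHAP values and its importance are exactly zero, whatever its expression.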

Usage

FindShapKeyDriver(
  seurat_obj,
  conditions,
  set_case,
  top_n = 50,
  max_depth = 4,
  eta = 0.1,
  nrounds = 50,
  nthread = 2,
  variable_genes = NULL,
  out_dir = NULL,
  mode = "full",
  train_fraction = 0.8,
  nfold = 5,
  seed = 123
)

Arguments

seurat_obj

A Seurat object with normalized gene expression and metadata.

conditions

Character. Name of the column in seurat_obj@meta.data that holds the condition labels.

set_case

Character. Label in the conditions column that defines the case group; cells with any other label form the comparison group.

top_n

Integer. Number of top SHAP-ranked genes to return (default: 50).

max_depth

Integer. Maximum depth of each XGBoost tree (default: 4).

eta

Numeric. Learning rate (shrinkage) for XGBoost (default: 0.1).

nrounds

Integer. Number of boosting rounds for XGBoost (default: 50).

nthread

Integer. Number of CPU threads to use for XGBoost (default: 2).

variable_genes

Optional character vector of genes to use as model features. If NULL, VariableFeatures(seurat_obj) is used.

out_dir

Optional directory to save results. If NULL, results are not saved to disk.

mode

Character. Training mode: "full", "split", or "cv" (default: "full").

train_fraction

Numeric. Fraction of cells used for training in "split" mode (default: 0.8).

nfold

Integer. Number of folds for cross-validation in "cv" mode (default: 5).

seed

Integer. Random seed for reproducibility (default: 123).
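Because set_case is compared against every other value of the conditions column, the outcome the classifier learns is effectively binary (case versus rest). A minimal sketch of that one-vs-rest encoding, assuming a 1/0 label; the function's internal encoding is not documented, and the metadata frame below (including the "MCI" label) is hypothetical.

```r
# Hypothetical stand-in for seurat_obj@meta.data
meta <- data.frame(
  Diagnosis = c("AD", "Control", "AD", "MCI", "Control"),
  row.names = paste0("cell", 1:5)
)
conditions <- "Diagnosis"
set_case   <- "AD"

# Fail early if the case label never occurs in the column
stopifnot(set_case %in% meta[[conditions]])

# One-vs-rest outcome: 1 = case, 0 = any other label
label <- as.integer(meta[[conditions]] == set_case)
print(label)  # [1] 1 0 1 0 0
```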

Value

The input Seurat object, with SHAP results stored in seurat_obj@misc$shap as the following elements:

  • model: Trained XGBoost model (only in "full" and "split" modes)

  • shap_result: Raw SHAP decomposition object (only in "full" and "split" modes)

  • shap_long: Long-format SHAP values (per-cell, per-gene)

  • shap_summary: Mean absolute SHAP values per gene

  • key_drivers: The top_n highest-ranked genes by SHAP value

  • variable_genes: Genes used in the model

  • test_auc: AUC on held-out test set (only in "split" mode)

  • auc_per_fold: Vector of AUC scores per fold (only in "cv" mode)

  • mean_auc: Mean AUC across folds (only in "cv" mode)

  • mode: Training mode used
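The relationship between shap_long, shap_summary, and key_drivers can be sketched in base R. The column names cell, gene, and shap, the gene names, and the values are hypothetical, and the real shap_long layout may differ; the sketch only shows how per-cell values reduce to a per-gene mean absolute SHAP value and a top_n cut.

```r
# Hypothetical long-format SHAP values: one row per cell-gene pair
shap_long <- data.frame(
  cell = rep(c("c1", "c2", "c3"), each = 3),
  gene = rep(c("GeneA", "GeneB", "GeneC"), times = 3),
  shap = c( 0.40, -0.10,  0.02,
           -0.60,  0.15, -0.01,
            0.50, -0.05,  0.03)
)

# Mean absolute SHAP value per gene, largest first (cf. shap_summary)
shap_summary <- sort(tapply(abs(shap_long$shap), shap_long$gene, mean),
                     decreasing = TRUE)

# Keep the top_n genes, never more than are available (cf. key_drivers)
top_n <- 2
key_drivers <- names(shap_summary)[seq_len(min(top_n, length(shap_summary)))]
print(key_drivers)  # [1] "GeneA" "GeneB"
```

Taking absolute values before averaging matters: GeneA pushes some cells toward the case group and others away from it, and a signed mean would partly cancel those effects.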

Details

Three model training modes are supported via the mode parameter:

  • "full": Train on the entire dataset (no test set, no cross-validation).

  • "split": Randomly split cells into a training subset (a train_fraction share of the cells) and a held-out test subset.

  • "cv": Perform k-fold cross-validation with k = nfold.
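A sketch of how the "split" and "cv" partitions can be drawn from train_fraction, nfold, and seed. It uses a plain random assignment; whether FindShapKeyDriver stratifies by label is not documented, and n_cells, train_idx, fold_id, and folds are illustrative names, not package internals.

```r
set.seed(123)          # same default as the seed argument
n_cells        <- 20   # illustrative number of cells
train_fraction <- 0.8
nfold          <- 5

# "split" mode: random train/test partition of cell indices
train_idx <- sample(n_cells, size = floor(train_fraction * n_cells))
test_idx  <- setdiff(seq_len(n_cells), train_idx)

# "cv" mode: give every cell a fold id, with fold sizes as equal as possible
fold_id <- sample(rep(seq_len(nfold), length.out = n_cells))
folds <- lapply(seq_len(nfold), function(k) {
  list(train = which(fold_id != k),   # fit the model on these cells
       test  = which(fold_id == k))   # compute this fold's AUC on these cells
})
```

Each cell lands in exactly one test fold, so every cell is scored once by a model that did not see it during training.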

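In "split" mode, the AUC on the held-out test set is computed and stored. In "cv" mode, the AUC of each fold and the mean AUC across folds are computed. In "full" mode, no AUC is computed, because evaluating the model on its own training data would overestimate its performance.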
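AUC is the probability that a randomly chosen case cell receives a higher predicted score than a randomly chosen non-case cell. The helper auc_rank() below is a hypothetical name implementing the rank-based (Mann-Whitney) formula; the package may compute AUC with a different library, but a correct implementation gives the same values on these inputs.

```r
# Rank-based (Mann-Whitney) AUC; auc_rank is a hypothetical helper name.
# score: predicted probability of the case class; label: 1 = case, 0 = rest
auc_rank <- function(score, label) {
  r     <- rank(score)            # ties receive average ranks
  n_pos <- sum(label == 1)
  n_neg <- sum(label == 0)
  (sum(r[label == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(c(0.9, 0.8, 0.3, 0.2), c(1, 1, 0, 0))  # 1 (perfect separation)
auc_rank(c(0.9, 0.4, 0.6, 0.2), c(1, 1, 0, 0))  # 0.75 (3 of 4 pairs ordered)
```

In "cv" mode, a per-fold vector of such values would be averaged with mean() to give a single summary like mean_auc.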

SHAP results and metadata are stored in seurat_obj@misc$shap. Optionally, results are also written to disk under out_dir, with every saved file described in info.txt.

Output Files (if out_dir is provided)

  • model.txt: XGBoost model summary (not in "cv" mode)

  • shap_result.rds: SHAP decomposition object (not in "cv" mode)

  • shap_summary.txt: Mean absolute SHAP value per gene

  • variable_genes.txt: List of genes used for model training

  • key_drivers.txt: Top SHAP-ranked genes

  • shap_long.rds: Full SHAP values in long format

  • auc.txt: AUC from held-out test set (only in "split" mode)

  • auc_per_fold.txt: AUC per fold (only in "cv" mode)

  • info.txt: Summary of saved outputs
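The mode-dependent file list above can be encoded as a small lookup, for example to check that a run wrote everything expected. expected_outputs() and missing_outputs() are hypothetical helpers, not part of the package; they check file names only, not file contents.

```r
# Files documented for each mode (hypothetical helper, not a package function)
expected_outputs <- function(mode = c("full", "split", "cv")) {
  mode <- match.arg(mode)
  files <- c("shap_summary.txt", "variable_genes.txt", "key_drivers.txt",
             "shap_long.rds", "info.txt")
  if (mode != "cv")    files <- c(files, "model.txt", "shap_result.rds")
  if (mode == "split") files <- c(files, "auc.txt")
  if (mode == "cv")    files <- c(files, "auc_per_fold.txt")
  files
}

# Names of documented files that are absent from out_dir
missing_outputs <- function(out_dir, mode) {
  setdiff(expected_outputs(mode), list.files(out_dir))
}
```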

Examples

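# These examples assume seurat_obj is an existing Seurat object whose
# meta.data has a "Diagnosis" column containing the label "AD".

# Run SHAP analysis in full mode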
seurat_obj <- FindShapKeyDriver(seurat_obj, conditions = "Diagnosis", set_case = "AD", top_n = 30)

# Run in split mode with custom output directory
seurat_obj <- FindShapKeyDriver(seurat_obj, conditions = "Diagnosis", set_case = "AD",
                                mode = "split", out_dir = "results/shap/")

# Run in 5-fold CV mode
seurat_obj <- FindShapKeyDriver(seurat_obj, conditions = "Diagnosis", set_case = "AD",
                                mode = "cv", nfold = 5)