Source code for seqgra.seqgra

#!/usr/bin/env python

"""MIT - CSAIL - Gifford Lab - seqgra

seqgra complete pipeline:
1. generate data based on data definition (once), see run_simulator.py
2. train model on data (once), see run_learner.py
3. evaluate model performance with SIS, see run_sis.py

@author: Konstantin Krismer
"""
import argparse
import logging
import os
import shutil
from typing import List, Optional

import seqgra
import seqgra.constants as c
from seqgra import MiscHelper
from seqgra.evaluator import Evaluator
from seqgra.evaluator import FeatureImportanceEvaluator
from seqgra.idresolver import IdResolver
from seqgra.learner import Learner
from seqgra.model import DataDefinition
from seqgra.model import ModelDefinition
from seqgra.parser import DataDefinitionParser
from seqgra.parser import XMLDataDefinitionParser
from seqgra.parser import ModelDefinitionParser
from seqgra.parser import XMLModelDefinitionParser
from seqgra.simulator import Simulator
from seqgra.simulator.heatmap import GrammarPositionHeatmap


def _remove_directory(directory: str, message: str,
                      logger: logging.Logger) -> None:
    """Delete ``directory`` (if it exists) and log ``message``."""
    if os.path.exists(directory):
        shutil.rmtree(directory, ignore_errors=True)
        logger.info(message)


def _clear_directory(directory: str, logger: logging.Logger) -> None:
    """Delete all files, links and subfolders inside ``directory``.

    Best effort: deletion failures are logged as warnings, not raised.
    """
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            logger.warning("Failed to delete %s. Reason: %s", file_path, e)


def _prepare_data(data_def_file: Optional[str], data_folder: Optional[str],
                  output_dir: str, print_info: bool, silent: bool,
                  remove_existing_data: bool, logger: logging.Logger):
    """Use experimental data or load / simulate synthetic data.

    Returns:
        tuple ``(data_definition, grammar_id, new_data)``:
        - the parsed data definition, or None for experimental data;
        - the grammar ID, i.e. the data folder name inside
          ``output_dir + "input/"``;
        - True if synthetic data was generated by this call.
    """
    if data_def_file is None:
        # experimental (or externally synthesized) data in input/<folder>
        grammar_id = data_folder.strip()
        logger.info("loaded experimental data")
        data_dir: str = output_dir + "input/" + grammar_id
        for data_set in (c.DataSet.TRAINING, c.DataSet.VALIDATION,
                         c.DataSet.TEST):
            GrammarPositionHeatmap.create(data_dir, data_set)
        return None, grammar_id, False

    # synthetic data based on a seqgra data definition
    data_config = MiscHelper.read_config_file(data_def_file)
    data_def_parser: DataDefinitionParser = XMLDataDefinitionParser(
        data_config, silent)
    data_definition: DataDefinition = data_def_parser.get_data_definition()
    grammar_id: str = data_definition.grammar_id

    if print_info:
        print(data_definition)

    if remove_existing_data:
        _remove_directory(output_dir + "input/" + grammar_id,
                          "removed existing synthetic data", logger)

    simulator = Simulator(data_definition, output_dir + "input", silent)
    if len(os.listdir(simulator.output_dir)) > 0:
        logger.info("loaded previously generated synthetic data")
        return data_definition, grammar_id, False

    logger.info("generating synthetic data")
    simulator.simulate_data()
    for data_set in (c.DataSet.TRAINING, c.DataSet.VALIDATION,
                     c.DataSet.TEST):
        simulator.create_grammar_heatmap(data_set)
    simulator.create_motif_info()
    simulator.create_motif_kl_divergence_matrix()
    simulator.create_empirical_similarity_score_matrix()
    return data_definition, grammar_id, True


def _load_or_train_model(learner: Learner, new_data: bool, in_memory: bool,
                         print_info: bool, logger: logging.Logger) -> bool:
    """Load a previously trained model, or train and save a new one.

    If the learner output folder is non-empty but loading fails, its
    contents are deleted and a new model is trained.

    Returns:
        True if a model was trained by this call and its library is not the
        Bayes optimal classifier; False otherwise.

    Raises:
        Exception: if a previously trained model was loaded although the
            data was regenerated in this run (the model is outdated).
    """
    trained_model_available: bool = len(os.listdir(learner.output_dir)) > 0
    train_model: bool = not trained_model_available
    if trained_model_available:
        try:
            learner.load_model()
            logger.info("loaded previously trained model")
        except Exception as exception:
            logger.warning("unable to load previously trained model; "
                           "previously trained model will be deleted "
                           "and new model will be trained; "
                           "exception caught: %s", str(exception))
            _clear_directory(learner.output_dir, logger)
            train_model = True

    # a model that was loaded from disk (not retrained) was fit on data that
    # has just been regenerated, i.e., it is outdated
    if new_data and not train_model:
        raise Exception("previously trained model used outdated "
                        "training data; delete '" + learner.output_dir +
                        "' and run seqgra again to train new model "
                        "on current data")

    if not train_model:
        return False

    logger.info("training model")
    learner.create_model()
    if print_info:
        learner.print_model_summary()

    if in_memory:
        training_set_file: str = learner.get_examples_file(
            c.DataSet.TRAINING)
        validation_set_file: str = learner.get_examples_file(
            c.DataSet.VALIDATION)
        x_train, y_train = learner.parse_examples_data(training_set_file)
        x_val, y_val = learner.parse_examples_data(validation_set_file)
        learner.train_model(x_train=x_train, y_train=y_train,
                            x_val=x_val, y_val=y_val)
    else:
        learner.train_model()
    learner.save_model()

    return learner.definition.library != \
        c.LibraryType.BAYES_OPTIMAL_CLASSIFIER


def _run_evaluators(learner: Learner, evaluator_ids: List[str],
                    output_dir: str, grammar_id: str, new_model: bool,
                    eval_sets: Optional[List[str]], eval_n: Optional[int],
                    eval_n_per_label: Optional[int],
                    eval_suppress_plots: Optional[bool],
                    eval_fi_predict_threshold: Optional[float],
                    eval_sis_predict_threshold: Optional[float],
                    eval_grad_importance_threshold: Optional[float],
                    silent: bool, logger: logging.Logger) -> None:
    """Run each evaluator on each selected data set.

    An evaluator whose results folder already exists and is non-empty is
    skipped. A warning is logged in that case if the model was just
    retrained.

    Raises:
        Exception: if ``eval_sets`` contains a set not in
            ``c.DataSet.ALL_SETS``.
    """
    logger.info("evaluating model using interpretability methods")

    if eval_sets:
        for eval_set in eval_sets:
            if eval_set not in c.DataSet.ALL_SETS:
                raise Exception(
                    "invalid set selected for evaluation: " + eval_set +
                    "; only the following sets are "
                    "allowed: " + ", ".join(c.DataSet.ALL_SETS))
    else:
        eval_sets = c.DataSet.ALL_SETS

    # eval_n_per_label (if set) overrides eval_n
    if eval_n_per_label:
        eval_n = eval_n_per_label

    evaluation_dir: str = output_dir + "evaluation/" + \
        grammar_id + "/" + learner.definition.model_id
    for evaluator_id in evaluator_ids:
        results_dir: str = evaluation_dir + "/" + evaluator_id
        results_exist: bool = os.path.exists(results_dir) and \
            len(os.listdir(results_dir)) > 0
        if results_exist:
            logger.info("skip evaluator %s: results already saved "
                        "to disk", evaluator_id)
            if new_model:
                logger.warning("results from evaluator %s are based "
                               "on an outdated model; delete "
                               "'%s' and run seqgra again to get "
                               "results from %s on current model",
                               evaluator_id, results_dir, evaluator_id)
            continue

        evaluator: Evaluator = IdResolver.get_evaluator(
            evaluator_id, learner, evaluation_dir,
            eval_sis_predict_threshold, eval_grad_importance_threshold,
            silent)
        is_fi_evaluator: bool = isinstance(evaluator,
                                           FeatureImportanceEvaluator)
        # The prediction threshold applies only to feature importance
        # evaluators. Keep it in a local so that a conventional evaluator
        # earlier in the list does not erase the threshold for later ones.
        subset_threshold: Optional[float] = \
            eval_fi_predict_threshold if is_fi_evaluator else None

        for eval_set in eval_sets:
            learner.set_seed()
            if is_fi_evaluator:
                logger.info("running feature importance "
                            "evaluator %s on %s set",
                            evaluator_id, eval_set)
            else:
                logger.info("running evaluator %s on %s set",
                            evaluator_id, eval_set)
            evaluator.evaluate_model(
                eval_set,
                subset_n=eval_n,
                subset_n_per_label=eval_n_per_label is not None,
                subset_threshold=subset_threshold,
                suppress_plots=eval_suppress_plots)


def run_seqgra(data_def_file: Optional[str],
               data_folder: Optional[str],
               model_def_file: Optional[str],
               evaluator_ids: Optional[List[str]],
               output_dir: str,
               in_memory: bool,
               print_info: bool,
               silent: bool,
               remove_existing_data: bool,
               gpu_id: int,
               no_checks: bool,
               eval_sets: Optional[List[str]],
               eval_n: Optional[int],
               eval_n_per_label: Optional[int],
               eval_suppress_plots: Optional[bool],
               eval_fi_predict_threshold: Optional[float],
               eval_sis_predict_threshold: Optional[float],
               eval_grad_importance_threshold: Optional[float]) -> None:
    """Run the seqgra pipeline: data, model training, and evaluation.

    1. Data: use the experimental data in ``output_dir/input/<data_folder>``,
       or load / simulate synthetic data described by ``data_def_file``.
    2. Model (only if ``model_def_file`` is given): load a previously trained
       model, or train and save a new one.
    3. Evaluation (only if ``evaluator_ids`` is non-empty): run every
       evaluator on every selected data set, skipping evaluators whose
       results already exist on disk.

    Arguments:
        data_def_file: path to the seqgra XML data definition file; if None,
            ``data_folder`` is used instead
        data_folder: experimental data folder name inside
            ``output_dir/input``; only used if ``data_def_file`` is None
        model_def_file: path to the seqgra XML model definition file; if
            None, only the data step runs
        evaluator_ids: IDs of the evaluators to run
        output_dir: root output directory
        in_memory: if True, training and validation sets are parsed into
            memory and passed to the learner's ``train_model``
        print_info: if True, data definition, model definition, and model
            summary are printed
        silent: if True, this module's logger level is set to the LOGLEVEL
            environment variable (WARNING if unset)
        remove_existing_data: if True, stored synthetic data, trained model,
            and evaluator results for this grammar / model are deleted first
        gpu_id: GPU ID passed to the learner factory
        no_checks: if True, example validation is disabled in the learner
        eval_sets: data sets to evaluate; all sets if None or empty
        eval_n: maximum number of examples evaluated per set
        eval_n_per_label: maximum number of examples evaluated per label and
            set; overrides ``eval_n``
        eval_suppress_plots: passed to every evaluator
        eval_fi_predict_threshold: prediction threshold used to select
            examples; applied to feature importance evaluators only
        eval_sis_predict_threshold: passed to the evaluator factory
        eval_grad_importance_threshold: passed to the evaluator factory

    Raises:
        Exception: if an invalid evaluation set is requested, or if a
            previously trained model was loaded although the data was
            regenerated in this run
    """
    logger = logging.getLogger(__name__)
    if silent:
        logger.setLevel(os.environ.get("LOGLEVEL", "WARNING"))

    output_dir = MiscHelper.format_output_dir(output_dir.strip())

    data_definition, grammar_id, new_data = _prepare_data(
        data_def_file, data_folder, output_dir, print_info, silent,
        remove_existing_data, logger)

    # without a model definition there is nothing to train or evaluate
    if model_def_file is None:
        return

    model_config = MiscHelper.read_config_file(model_def_file)
    model_def_parser: ModelDefinitionParser = XMLModelDefinitionParser(
        model_config, silent)
    model_definition: ModelDefinition = \
        model_def_parser.get_model_definition()

    if print_info:
        print(model_definition)

    if remove_existing_data:
        _remove_directory(output_dir + "models/" + grammar_id + "/" +
                          model_definition.model_id,
                          "removed pretrained model", logger)

    learner: Learner = IdResolver.get_learner(
        model_definition, data_definition,
        output_dir + "input/" + grammar_id,
        output_dir + "models/" + grammar_id,
        not no_checks, gpu_id, silent)

    new_model: bool = _load_or_train_model(learner, new_data, in_memory,
                                           print_info, logger)

    if remove_existing_data:
        _remove_directory(output_dir + "evaluation/" + grammar_id + "/" +
                          model_definition.model_id,
                          "removed evaluator results", logger)

    if not evaluator_ids:
        logger.info("skipping evaluation step: no evaluator specified")
        return

    _run_evaluators(
        learner=learner,
        evaluator_ids=evaluator_ids,
        output_dir=output_dir,
        grammar_id=grammar_id,
        new_model=new_model,
        eval_sets=eval_sets,
        eval_n=eval_n,
        eval_n_per_label=eval_n_per_label,
        eval_suppress_plots=eval_suppress_plots,
        eval_fi_predict_threshold=eval_fi_predict_threshold,
        eval_sis_predict_threshold=eval_sis_predict_threshold,
        eval_grad_importance_threshold=eval_grad_importance_threshold,
        silent=silent,
        logger=logger)
def create_parser() -> argparse.ArgumentParser:
    """Build the command line parser of the ``seqgra`` tool.

    Exactly one of -d/--data-def-file and -f/--data-folder is required, as
    is -o/--output-dir. Cross-argument checks (e.g., -e requires -m) are
    performed by the caller after parsing.

    Returns:
        the configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(
        prog="seqgra",
        description="Generate synthetic data based on grammar, train model on "
        "synthetic data, evaluate model")
    parser.add_argument(
        "-v", "--version", action="version",
        version="%(prog)s " + seqgra.__version__)

    # data source: either a data definition (synthetic) or a data folder
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "-d",
        "--data-def-file",
        type=str,
        help="path to the seqgra XML data definition file. Use this option "
        "to generate synthetic data based on a seqgra grammar (specify "
        "either -d or -f, not both)"
    )
    group.add_argument(
        "-f",
        "--data-folder",
        type=str,
        help="experimental data folder name inside outputdir/input. Use this "
        "option to train the model on experimental or externally synthesized "
        "data (specify either -f or -d, not both)"
    )
    parser.add_argument(
        "-m",
        "--model-def-file",
        type=str,
        help="path to the seqgra XML model definition file"
    )
    parser.add_argument(
        "-e",
        "--evaluators",
        type=str,
        default=None,
        nargs="+",
        help="evaluator ID or IDs: IDs of "
        "conventional evaluators include " +
        ", ".join(sorted(c.EvaluatorID.CONVENTIONAL_EVALUATORS)) +
        "; IDs of feature importance evaluators include " +
        ", ".join(sorted(c.EvaluatorID.FEATURE_IMPORTANCE_EVALUATORS))
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help="output directory, subdirectories are created for generated "
        "data, trained model, and model evaluation"
    )
    parser.add_argument(
        "-i",
        "--in-memory",
        action="store_true",
        help="if this flag is set, training and validation data will be "
        "stored in-memory instead of loaded in chunks"
    )
    parser.add_argument(
        "-p",
        "--print",
        action="store_true",
        help="if this flag is set, data definition, model definition, and "
        "model summary are printed"
    )
    parser.add_argument(
        "-s",
        "--silent",
        action="store_true",
        help="if this flag is set, only warnings and errors are printed"
    )
    parser.add_argument(
        "-r",
        "--remove",
        action="store_true",
        help="if this flag is set, previously stored data for this grammar - "
        "model combination will be removed prior to the analysis run. This "
        "includes the folders input/[grammar ID], "
        "models/[grammar ID]/[model ID], and "
        "evaluation/[grammar ID]/[model ID]."
    )
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        default=0,
        help="ID of GPU used by TensorFlow and PyTorch (defaults to GPU "
        "ID 0); CPU is used if no GPU is available or GPU ID is set to -1"
    )
    parser.add_argument(
        "--no-checks",
        action="store_true",
        help="if this flag is set, examples and example annotations will not "
        "be validated before training, e.g., that DNA sequences only contain "
        "A, C, G, T, N"
    )

    # evaluator arguments
    parser.add_argument(
        "--eval-sets",
        type=str,
        default=[c.DataSet.TEST],
        nargs="+",
        help="either one or more of the following: training, validation, "
        "test; selects data set for evaluation; this evaluator argument "
        "will be passed to all evaluators"
    )
    parser.add_argument(
        "--eval-n",
        type=int,
        help="maximum number of examples to be evaluated per set (defaults "
        "to the total number of examples); this evaluator argument "
        "will be passed to all evaluators"
    )
    parser.add_argument(
        "--eval-n-per-label",
        type=int,
        help="maximum number of examples to be evaluated for each label and "
        "set (defaults to the total number of examples unless eval-n is set, "
        "overrules eval-n); "
        "this evaluator argument will be passed to all evaluators"
    )
    parser.add_argument(
        "--eval-suppress-plots",
        action="store_true",
        help="if this flag is set, plots are suppressed globally; "
        "this evaluator argument will be passed to all evaluators"
    )
    parser.add_argument(
        "--eval-fi-predict-threshold",
        type=float,
        default=0.5,
        help="prediction threshold used to select examples for evaluation, "
        "only examples with predict(x) > threshold will be passed on to "
        "evaluators (defaults to 0.5); "
        "this evaluator argument will be passed to feature importance "
        "evaluators only"
    )
    parser.add_argument(
        "--eval-sis-predict-threshold",
        type=float,
        default=0.5,
        help="prediction threshold for Sufficient Input Subsets "
        "(defaults to 0.5); "
        "this evaluator argument is only visible to the SIS evaluator"
    )
    parser.add_argument(
        "--eval-grad-importance-threshold",
        type=float,
        default=0.01,
        help="feature importance threshold for gradient-based feature "
        "importance evaluators; this parameter only affects thresholded "
        "grammar agreement plots, not the feature importance measures "
        "themselves; this evaluator argument is only visible to "
        "gradient-based feature importance evaluators (defaults to 0.01)"
    )
    return parser
def main():
    """Entry point of the ``seqgra`` command line tool.

    Configures INFO-level logging and parses the command line. It then
    checks that options which need a model definition (-f, -e) come with
    -m, validates the evaluator IDs, and hands everything to run_seqgra.

    Raises:
        ValueError: if an evaluator ID is not in
            ``c.EvaluatorID.ALL_EVALUATOR_IDS`` (reports the first one)
    """
    logging.basicConfig(level=logging.INFO)
    arg_parser = create_parser()
    args = arg_parser.parse_args()

    # -f and -e are only meaningful when a model is trained / evaluated;
    # ArgumentParser.error exits, so the first failing check wins
    if args.model_def_file is None:
        if args.data_folder:
            arg_parser.error(
                "-f/--data-folder requires -m/--model-def-file.")
        if args.evaluators:
            arg_parser.error(
                "-e/--evaluators requires -m/--model-def-file.")

    unknown_ids = [evaluator_id
                   for evaluator_id in (args.evaluators or [])
                   if evaluator_id not in c.EvaluatorID.ALL_EVALUATOR_IDS]
    if unknown_ids:
        raise ValueError(
            "invalid evaluator ID {s!r}".format(s=unknown_ids[0]))

    run_seqgra(data_def_file=args.data_def_file,
               data_folder=args.data_folder,
               model_def_file=args.model_def_file,
               evaluator_ids=args.evaluators,
               output_dir=args.output_dir,
               in_memory=args.in_memory,
               print_info=args.print,
               silent=args.silent,
               remove_existing_data=args.remove,
               gpu_id=args.gpu,
               no_checks=args.no_checks,
               eval_sets=args.eval_sets,
               eval_n=args.eval_n,
               eval_n_per_label=args.eval_n_per_label,
               eval_suppress_plots=args.eval_suppress_plots,
               eval_fi_predict_threshold=args.eval_fi_predict_threshold,
               eval_sis_predict_threshold=args.eval_sis_predict_threshold,
               eval_grad_importance_threshold=(
                   args.eval_grad_importance_threshold))


if __name__ == "__main__":
    main()