Source code for seqgra.learner.torch.torchhelper

"""MIT - CSAIL - Gifford Lab - seqgra

PyTorch learner helper class

@author: Konstantin Krismer
"""
from ast import literal_eval
from distutils.util import strtobool
import importlib
import logging
import os
import random
import shutil
import sys
from typing import FrozenSet, List, Optional

from ignite.engine import Events
from ignite.engine import create_supervised_trainer
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import EarlyStopping, ModelCheckpoint
import numpy as np
import pkg_resources
import torch

from seqgra import ModelSize
import seqgra.constants as c
from seqgra.learner import Learner


class TorchHelper:
    MULTI_CLASS_CLASSIFICATION_LOSSES: FrozenSet[str] = frozenset(
        ["crossentropyloss", "nllloss", "kldivloss", "hingeembeddingloss",
         "cosineembeddingloss"])
    MULTI_LABEL_CLASSIFICATION_LOSSES: FrozenSet[str] = frozenset(
        ["bcewithlogitsloss", "bceloss"])
    MULTIPLE_REGRESSION_LOSSES: FrozenSet[str] = frozenset(
        ["l1loss", "mseloss", "smoothl1loss"])
    MULTIVARIATE_REGRESSION_LOSSES: FrozenSet[str] = MULTIPLE_REGRESSION_LOSSES
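The loss-name sets above are keyed on lower-cased names with underscores stripped, the same normalization that get_loss applies further down. A minimal sketch of how a caller might check a configured loss against the task type, assuming the module is imported; the loss_name value is made up for illustration:

loss_name = "cross_entropy_loss".lower().replace("_", "").strip()
assert loss_name in TorchHelper.MULTI_CLASS_CLASSIFICATION_LOSSES
assert loss_name not in TorchHelper.MULTI_LABEL_CLASSIFICATION_LOSSES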
    @staticmethod
    def create_model(learner: Learner) -> None:
        path = learner.definition.architecture.external_model_path
        class_name = learner.definition.architecture.external_model_class_name

        learner.set_seed()

        if path is None:
            raise Exception("embedded architecture definition not supported"
                            " for PyTorch models")
        elif path is not None and \
                learner.definition.architecture.external_model_format is not None:
            if learner.definition.architecture.external_model_format == "pytorch-module":
                if os.path.isfile(path):
                    if class_name is None:
                        raise Exception(
                            "PyTorch model class name not specified")
                    else:
                        module_spec = importlib.util.spec_from_file_location(
                            "model", path)
                        torch_model_module = importlib.util.module_from_spec(
                            module_spec)
                        module_spec.loader.exec_module(torch_model_module)
                        torch_model_class = getattr(torch_model_module,
                                                    class_name)
                        learner.model = torch_model_class()
                else:
                    raise Exception(
                        "PyTorch model class file does not exist: " + path)
            else:
                raise Exception(
                    "unsupported PyTorch model format: " +
                    learner.definition.architecture.external_model_format)
        else:
            raise Exception("neither internal nor external architecture "
                            "definition provided")

        if learner.definition.optimizer_hyperparameters is None:
            raise Exception("optimizer undefined")
        else:
            learner.optimizer = TorchHelper.get_optimizer(
                learner.definition.optimizer_hyperparameters,
                learner.model.parameters())

        if learner.definition.loss_hyperparameters is None:
            raise Exception("loss undefined")
        else:
            learner.criterion = TorchHelper.get_loss(
                learner.definition.loss_hyperparameters)

        if learner.metrics is None:
            raise Exception("metrics undefined")
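create_model only accepts external architecture definitions in the "pytorch-module" format: a plain .py file defining a torch.nn.Module subclass, which is imported with importlib and instantiated without arguments. A hedged sketch of what such a file could look like; the file name (my_model.py), class name, and layer sizes are hypothetical, and the class must be constructible with no arguments because the code above calls torch_model_class():

# hypothetical contents of my_model.py, referenced by external_model_path
# and external_model_class_name in the seqgra model definition
import torch


class MyDNAClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # assumes one-hot encoded DNA input with 4 channels
        self.conv = torch.nn.Conv1d(4, 16, kernel_size=11, padding=5)
        self.pool = torch.nn.AdaptiveMaxPool1d(1)
        self.fc = torch.nn.Linear(16, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).squeeze(-1)
        return self.fc(x)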
    @staticmethod
    def print_model_summary(learner: Learner):
        if learner.model:
            print(learner.model)
        else:
            print("uninitialized model")
    @staticmethod
    def set_seed(learner: Learner) -> None:
        random.seed(learner.definition.seed)
        np.random.seed(learner.definition.seed)
        torch.manual_seed(learner.definition.seed)
    @staticmethod
    def train_model(
            learner: Learner,
            training_dataset: torch.utils.data.Dataset,
            validation_dataset: torch.utils.data.Dataset,
            output_layer_activation_function: Optional[str] = None,
            silent: bool = False) -> None:
        logger = logging.getLogger(__name__)
        if silent:
            logger.setLevel(os.environ.get("LOGLEVEL", "WARNING"))

        if learner.model is None:
            learner.create_model()

        # save number of model parameters
        num_trainable_params, num_non_trainable_params = \
            learner.get_num_params()
        with open(learner.output_dir + "num-model-parameters.txt",
                  "w") as model_param_file:
            model_param_file.write("number of trainable parameters\t" +
                                   str(num_trainable_params) + "\n")
            model_param_file.write("number of non-trainable parameters\t" +
                                   str(num_non_trainable_params) + "\n")
            model_param_file.write(
                "number of all parameters\t" +
                str(num_trainable_params + num_non_trainable_params) + "\n")

        batch_size = int(
            learner.definition.training_process_hyperparameters["batch_size"])

        # init data loaders
        if isinstance(training_dataset, torch.utils.data.IterableDataset):
            # examples are shuffled in IterableDataSet class
            shuffle: bool = False
        else:
            shuffle: bool = bool(strtobool(
                learner.definition.training_process_hyperparameters["shuffle"]))
        training_loader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=shuffle)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=False)

        # GPU or CPU?
        learner.model = learner.model.to(learner.device)
        logger.info("using device: %s", learner.device_label)

        # training loop
        trainer = create_supervised_trainer(learner.model, learner.optimizer,
                                            learner.criterion,
                                            device=learner.device)
        train_evaluator = create_supervised_evaluator(
            learner.model,
            metrics=TorchHelper.get_metrics(
                learner, output_layer_activation_function),
            device=learner.device)
        val_evaluator = create_supervised_evaluator(
            learner.model,
            metrics=TorchHelper.get_metrics(
                learner, output_layer_activation_function),
            device=learner.device)

        logging.getLogger("ignite.engine.engine.Engine").setLevel(
            logging.WARNING)

        num_epochs: int = int(
            learner.definition.training_process_hyperparameters["epochs"])

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(trainer):
            logger.info("epoch {}/{}".format(trainer.state.epoch, num_epochs))
            train_evaluator.run(training_loader)
            metrics = train_evaluator.state.metrics
            logger.info(TorchHelper._format_metrics_output(
                metrics, "training set"))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            val_evaluator.run(validation_loader)
            metrics = val_evaluator.state.metrics
            logger.info(TorchHelper._format_metrics_output(metrics,
                                                           "validation set"))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_last_epoch(trainer):
            with open(learner.output_dir + "last-epoch-completed.txt",
                      "w") as last_epoch_file:
                last_epoch_file.write(str(trainer.state.epoch) + "\n")

        # save best model
        def score_fn(engine):
            if "loss" in learner.metrics:
                score = engine.state.metrics["loss"]
                score = -score
            elif "accuracy" in learner.metrics:
                score = engine.state.metrics["accuracy"]
            else:
                raise Exception("no metric to track performance")
            return score

        best_model_dir: str = learner.output_dir + "tmp"
        best_model_saver_handler = ModelCheckpoint(
            best_model_dir, score_function=score_fn, filename_prefix="best",
            n_saved=1, create_dir=True)
        val_evaluator.add_event_handler(Events.COMPLETED,
                                        best_model_saver_handler,
                                        {"model": learner.model})

        # early stopping callback
        if bool(strtobool(
                learner.definition.training_process_hyperparameters["early_stopping"])):
            if "early_stopping_patience" in \
                    learner.definition.training_process_hyperparameters:
                patience: int = int(
                    learner.definition.training_process_hyperparameters["early_stopping_patience"])
            else:
                patience: int = 10
            es_handler = EarlyStopping(patience=patience,
                                       score_function=score_fn,
                                       trainer=trainer, min_delta=0)
            val_evaluator.add_event_handler(Events.COMPLETED, es_handler)

        trainer.run(training_loader, max_epochs=num_epochs)

        # load best model after training
        best_model = TorchHelper.get_best_model_file_name(best_model_dir)
        if best_model:
            learner.load_model("tmp/" + best_model)
            # remove temp folder
            shutil.rmtree(best_model_dir)
        else:
            logger.warning("best model could not be loaded")
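train_model wires checkpointing and early stopping around the validation evaluator: score_fn negates the loss so that ignite's "larger is better" convention applies, ModelCheckpoint keeps the single best state_dict in a temporary directory, and EarlyStopping halts the trainer once the score stops improving. Below is a self-contained sketch of the same wiring on a toy two-class problem, assuming pytorch-ignite is installed; the model, data, and hyperparameters are made up, and the exact ModelCheckpoint keyword arguments vary with the ignite version:

import torch
from ignite.engine import Events, create_supervised_trainer, \
    create_supervised_evaluator
from ignite.handlers import EarlyStopping, ModelCheckpoint
from ignite.metrics import Accuracy, Loss

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

x = torch.randn(64, 10)
y = torch.randint(0, 2, (64,))
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=16)

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(
    model, metrics={"loss": Loss(criterion), "accuracy": Accuracy()})


def score_fn(engine):
    # ignite handlers treat larger scores as better, hence the negated loss
    return -engine.state.metrics["loss"]


checkpoint = ModelCheckpoint("tmp", "best", score_function=score_fn,
                             n_saved=1)
evaluator.add_event_handler(Events.COMPLETED, checkpoint, {"model": model})
evaluator.add_event_handler(
    Events.COMPLETED,
    EarlyStopping(patience=10, score_function=score_fn, trainer=trainer))


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(loader)


trainer.run(loader, max_epochs=5)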
    @staticmethod
    def evaluate_model(learner: Learner, dataset: torch.utils.data.Dataset,
                       output_layer_activation_function: Optional[str] = None):
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=int(
                learner.definition.training_process_hyperparameters["batch_size"]),
            shuffle=False)

        learner.model = learner.model.to(learner.device)

        running_loss: float = 0.0
        running_correct: int = 0
        num_examples: int = 0

        learner.model.eval()
        with torch.no_grad():
            for x, y in data_loader:
                # transfer to device
                x = x.to(learner.device)
                y = y.to(learner.device)

                y_hat = learner.model(x)
                loss = learner.criterion(y_hat, y)

                if output_layer_activation_function is not None:
                    if output_layer_activation_function == "softmax":
                        y_hat = torch.nn.functional.softmax(y_hat, dim=1)
                    elif output_layer_activation_function == "sigmoid":
                        y_hat = torch.sigmoid(y_hat)

                # binarize y_hat
                if learner.definition.task == c.TaskType.MULTI_CLASS_CLASSIFICATION:
                    indices = torch.argmax(y_hat, dim=1)
                    correct = torch.eq(indices, y).view(-1)
                elif learner.definition.task == c.TaskType.MULTI_LABEL_CLASSIFICATION:
                    y_hat = torch.gt(y_hat, 0.5)
                    y = y.type_as(y_hat)
                    correct = torch.all(y == y_hat, dim=-1)

                running_correct += torch.sum(correct).item()
                running_loss += loss.item() * x.size(0)
                num_examples += correct.shape[0]

        overall_loss = running_loss / num_examples
        overall_accuracy = running_correct / num_examples
        return {"loss": overall_loss, "accuracy": overall_accuracy}
    @staticmethod
    def predict(learner: Learner, dataset: torch.utils.data.Dataset,
                output_layer_activation_function: Optional[str] = None):
        """Forward pass from x to y.

        Returns:
            NumPy array of model outputs: raw logits, or probabilities if an
            output layer activation function (softmax or sigmoid) is
            specified.
        """
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=int(
                learner.definition.training_process_hyperparameters["batch_size"]),
            shuffle=False)

        learner.model = learner.model.to(learner.device)

        y_hat = []
        learner.model.eval()
        with torch.no_grad():
            for x in data_loader:
                # transfer to device
                x = x.to(learner.device)

                raw_logits = learner.model(x)
                if output_layer_activation_function is None:
                    y_hat += raw_logits.tolist()
                elif output_layer_activation_function == "softmax":
                    y_hat += \
                        torch.nn.functional.softmax(raw_logits, dim=1).tolist()
                elif output_layer_activation_function == "sigmoid":
                    y_hat += torch.sigmoid(raw_logits).tolist()
        return np.array(y_hat)
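The output_layer_activation_function argument only post-processes the raw logits; which value is appropriate follows from the loss, for example softmax with CrossEntropyLoss (multi-class) and sigmoid with BCEWithLogitsLoss (multi-label). A toy illustration with made-up logits:

import torch

raw_logits = torch.tensor([[2.0, -1.0], [0.5, 0.5]])
probs = torch.nn.functional.softmax(raw_logits, dim=1)  # rows sum to 1
multi_label_probs = torch.sigmoid(raw_logits)  # independent per label
print(probs.tolist(), multi_label_probs.tolist())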
    @staticmethod
    def get_best_model_file_name(best_model_dir: str) -> Optional[str]:
        model_files = [model_file
                       for model_file in os.listdir(best_model_dir)
                       if model_file.endswith(".pth") or
                       model_file.endswith(".pt")]

        if len(model_files) == 1:
            return model_files[0]
        else:
            return None
    @staticmethod
    def _format_metrics_output(metrics, set_label):
        message: List[str] = [set_label + " metrics:\n"]
        message += [" - " + metric + ": " + str(metrics[metric]) + "\n"
                    for metric in metrics]
        return "".join(message).rstrip()
    @staticmethod
    def train_model_basic(
            learner: Learner,
            training_dataset: torch.utils.data.Dataset,
            validation_dataset: torch.utils.data.Dataset,
            output_layer_activation_function: Optional[str] = None,
            silent: bool = False) -> None:
        logger = logging.getLogger(__name__)
        if silent:
            logger.setLevel(os.environ.get("LOGLEVEL", "WARNING"))

        if learner.model is None:
            learner.create_model()

        batch_size = int(
            learner.definition.training_process_hyperparameters["batch_size"])

        # init data loaders
        training_loader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size,
            shuffle=bool(strtobool(
                learner.definition.training_process_hyperparameters["shuffle"])))
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=False)

        # GPU or CPU?
        learner.model = learner.model.to(learner.device)
        logger.info("using device: %s", learner.device_label)

        # training loop
        num_epochs: int = int(
            learner.definition.training_process_hyperparameters["epochs"])
        for epoch in range(num_epochs):
            logger.info("epoch {}/{}".format(epoch + 1, num_epochs))

            for phase in [c.DataSet.TRAINING, c.DataSet.VALIDATION]:
                if phase == c.DataSet.TRAINING:
                    learner.model.train()
                    data_loader = training_loader
                else:
                    learner.model.eval()
                    data_loader = validation_loader

                running_loss: float = 0.0
                running_correct: int = 0

                for x, y in data_loader:
                    # transfer to device
                    x = x.to(learner.device)
                    y = y.to(learner.device)

                    # zero the parameter gradients
                    learner.optimizer.zero_grad()

                    # forward
                    # track history only in training phase
                    with torch.set_grad_enabled(phase == c.DataSet.TRAINING):
                        y_hat = learner.model(x)
                        loss = learner.criterion(y_hat, y)

                        # backward + optimize only if in training phase
                        if phase == c.DataSet.TRAINING:
                            loss.backward()
                            learner.optimizer.step()

                    if output_layer_activation_function is not None:
                        if output_layer_activation_function == "softmax":
                            y_hat = torch.nn.functional.softmax(
                                y_hat, dim=1)
                        elif output_layer_activation_function == "sigmoid":
                            y_hat = torch.sigmoid(y_hat)

                    # statistics
                    if learner.definition.task == c.TaskType.MULTI_CLASS_CLASSIFICATION:
                        indices = torch.argmax(y_hat, dim=1)
                        correct = torch.eq(indices, y).view(-1)
                    elif learner.definition.task == c.TaskType.MULTI_LABEL_CLASSIFICATION:
                        # binarize y_hat
                        y_hat = torch.gt(y_hat, 0.5)
                        y = y.type_as(y_hat)
                        correct = torch.all(y == y_hat, dim=-1)

                    running_correct += torch.sum(correct).item()
                    running_loss += loss.item() * x.size(0)

                epoch_loss = running_loss / len(data_loader.dataset)
                # running_correct is a Python int, so plain division yields
                # the accuracy as a float
                epoch_acc = running_correct / len(data_loader.dataset)
                logger.info("{} - loss: {:.3f}, accuracy: {:.3f}".format(
                    phase, epoch_loss, epoch_acc))
    @staticmethod
    def save_model(learner: Learner, file_name: Optional[str] = None) -> None:
        if not file_name:
            file_name = "saved_model.pth"

        # save session info
        learner.write_session_info()

        if os.path.dirname(file_name):
            # create intermediate directories if they do not exist yet
            os.makedirs(learner.output_dir + os.path.dirname(file_name),
                        exist_ok=True)

        torch.save(learner.model.state_dict(),
                   learner.output_dir + file_name)
    @staticmethod
    def write_session_info(learner: Learner) -> None:
        with open(learner.output_dir + "session-info.txt", "w") as session_file:
            session_file.write(
                "seqgra package version: " +
                pkg_resources.require("seqgra")[0].version + "\n")
            session_file.write("PyTorch version: " + torch.__version__ + "\n")
            session_file.write("NumPy version: " + np.version.version + "\n")
            session_file.write("Python version: " + sys.version + "\n")
    @staticmethod
    def load_model(learner: Learner, file_name: Optional[str] = None) -> None:
        if not file_name:
            file_name = "saved_model.pth"

        TorchHelper.create_model(learner)
        learner.model.load_state_dict(torch.load(learner.output_dir +
                                                 file_name))
    @staticmethod
    def get_num_params(learner: Learner) -> ModelSize:
        if learner.model is None:
            learner.create_model()
        num_trainable_params: int = sum(
            param.numel()
            for param in learner.model.parameters()
            if param.requires_grad)
        num_all_params: int = sum(
            param.numel()
            for param in learner.model.parameters())
        return ModelSize(num_trainable_params,
                         num_all_params - num_trainable_params)
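get_num_params splits the parameter count by requires_grad. A quick check of the same counting logic on a made-up model, a Linear layer with its bias frozen:

import torch

model = torch.nn.Linear(10, 2)
model.bias.requires_grad = False
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total - trainable)  # 20 2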
[docs] @staticmethod def get_optimizer(optimizer_hyperparameters, model_parameters): if "optimizer" in optimizer_hyperparameters: optimizer = \ optimizer_hyperparameters["optimizer"].lower().strip() if "learning_rate" in optimizer_hyperparameters: learning_rate = float( optimizer_hyperparameters["learning_rate"].strip()) else: learning_rate = 0.001 if "rho" in optimizer_hyperparameters: rho = float( optimizer_hyperparameters["rho"].strip()) else: rho = 0.9 if "eps" in optimizer_hyperparameters: eps = float( optimizer_hyperparameters["eps"].strip()) else: eps = 1e-08 if "weight_decay" in optimizer_hyperparameters: weight_decay = float( optimizer_hyperparameters["weight_decay"].strip()) else: weight_decay = 0.0 if "momentum" in optimizer_hyperparameters: momentum = float( optimizer_hyperparameters["momentum"].strip()) else: momentum = 0.0 if "lr_decay" in optimizer_hyperparameters: lr_decay = float( optimizer_hyperparameters["lr_decay"].strip()) else: lr_decay = 0.0 if "initial_accumulator_value" in optimizer_hyperparameters: initial_accumulator_value = float( optimizer_hyperparameters["initial_accumulator_value"].strip()) else: initial_accumulator_value = 0.0 if "betas" in optimizer_hyperparameters: betas = literal_eval( optimizer_hyperparameters["betas"].strip()) else: betas = (0.9, 0.999) if "amsgrad" in optimizer_hyperparameters: amsgrad = bool(strtobool( optimizer_hyperparameters["amsgrad"].strip())) else: amsgrad = False if "lambd" in optimizer_hyperparameters: lambd = float( optimizer_hyperparameters["lambd"].strip()) else: lambd = 0.0001 if "alpha" in optimizer_hyperparameters: alpha = float( optimizer_hyperparameters["alpha"].strip()) else: alpha = 0.75 if "t0" in optimizer_hyperparameters: t0 = float( optimizer_hyperparameters["t0"].strip()) else: t0 = 1000000.0 if "max_iter" in optimizer_hyperparameters: max_iter = int( optimizer_hyperparameters["max_iter"].strip()) else: max_iter = 20 if "max_eval" in optimizer_hyperparameters: max_eval = int( optimizer_hyperparameters["max_eval"].strip()) else: max_eval = None if "tolerance_grad" in optimizer_hyperparameters: tolerance_grad = float( optimizer_hyperparameters["tolerance_grad"].strip()) else: tolerance_grad = 1e-07 if "tolerance_change" in optimizer_hyperparameters: tolerance_change = float( optimizer_hyperparameters["tolerance_change"].strip()) else: tolerance_change = 1e-09 if "history_size" in optimizer_hyperparameters: history_size = int( optimizer_hyperparameters["history_size"].strip()) else: history_size = 100 if "line_search_fn" in optimizer_hyperparameters: line_search_fn = \ optimizer_hyperparameters["line_search_fn"].strip() else: line_search_fn = None if "centered" in optimizer_hyperparameters: centered = bool(strtobool( optimizer_hyperparameters["centered"].strip())) else: centered = False if "etas" in optimizer_hyperparameters: etas = literal_eval(optimizer_hyperparameters["etas"].strip()) else: etas = (0.5, 1.2) if "step_sizes" in optimizer_hyperparameters: step_sizes = literal_eval( optimizer_hyperparameters["step_sizes"].strip()) else: step_sizes = (1e-06, 50) if "dampening" in optimizer_hyperparameters: dampening = float( optimizer_hyperparameters["dampening"].strip()) else: dampening = 0.0 if "nesterov" in optimizer_hyperparameters: nesterov = bool(strtobool( optimizer_hyperparameters["nesterov"].strip())) else: nesterov = False if optimizer == "adadelta": if "eps" not in optimizer_hyperparameters: eps = 1e-06 return torch.optim.Adadelta( model_parameters, lr=learning_rate, rho=rho, eps=eps, 
weight_decay=weight_decay) elif optimizer == "adagrad": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.01 if "eps" not in optimizer_hyperparameters: eps = 1e-10 return torch.optim.Adagrad( model_parameters, lr=learning_rate, lr_decay=lr_decay, weight_decay=weight_decay, initial_accumulator_value=initial_accumulator_value, eps=eps) elif optimizer == "adam": return torch.optim.Adam( model_parameters, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) elif optimizer == "adamw": if "weight_decay" not in optimizer_hyperparameters: weight_decay = 0.01 return torch.optim.AdamW( model_parameters, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) elif optimizer == "sparseadam": return torch.optim.SparseAdam( model_parameters, lr=learning_rate, betas=betas, eps=eps) elif optimizer == "adamax": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.002 return torch.optim.Adamax( model_parameters, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay) elif optimizer == "asgd": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.01 return torch.optim.ASGD( model_parameters, lr=learning_rate, lambd=lambd, alpha=alpha, t0=t0, weight_decay=weight_decay) elif optimizer == "lbfgs": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 1.0 return torch.optim.LBFGS( model_parameters, lr=learning_rate, max_iter=max_iter, max_eval=max_eval, tolerance_grad=tolerance_grad, tolerance_change=tolerance_change, history_size=history_size, line_search_fn=line_search_fn) elif optimizer == "rmsprop": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.01 if "alpha" not in optimizer_hyperparameters: alpha = 0.99 return torch.optim.RMSprop( model_parameters, lr=learning_rate, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered) elif optimizer == "rprop": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.01 return torch.optim.Rprop( model_parameters, lr=learning_rate, etas=etas, step_sizes=step_sizes) elif optimizer == "rprop": if "learning_rate" not in optimizer_hyperparameters: learning_rate = 0.01 return torch.optim.Rprop( model_parameters, lr=learning_rate, etas=etas, step_sizes=step_sizes) elif optimizer == "sgd": return torch.optim.SGD( model_parameters, lr=learning_rate, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) else: raise Exception("unknown optimizer specified: " + optimizer) else: raise Exception("no optimizer specified")
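get_optimizer expects string-valued hyperparameters (each value is stripped and converted before use) and falls back to the defaults listed above for anything not provided. An illustrative call, assuming the module is imported; the hyperparameter dictionary and throwaway model are made up:

optimizer_hyperparameters = {
    "optimizer": "adam",
    "learning_rate": "0.0005",
    "betas": "(0.9, 0.999)",
    "weight_decay": "0.0001",
}
model = torch.nn.Linear(10, 2)
optimizer = TorchHelper.get_optimizer(optimizer_hyperparameters,
                                      model.parameters())
# -> torch.optim.Adam with lr=0.0005, betas=(0.9, 0.999), weight_decay=0.0001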
    @staticmethod
    def get_loss(loss_hyperparameters):
        if "loss" in loss_hyperparameters:
            loss = loss_hyperparameters["loss"].lower().replace(
                "_", "").strip()
            if loss == "crossentropyloss":
                return torch.nn.CrossEntropyLoss()
            elif loss == "nllloss":
                return torch.nn.NLLLoss()
            elif loss == "bceloss":
                return torch.nn.BCELoss()
            elif loss == "bcewithlogitsloss":
                return torch.nn.BCEWithLogitsLoss()
            elif loss == "l1loss":
                return torch.nn.L1Loss()
            elif loss == "mseloss":
                return torch.nn.MSELoss()
            elif loss == "smoothl1loss":
                return torch.nn.SmoothL1Loss()
            elif loss == "kldivloss":
                return torch.nn.KLDivLoss()
            elif loss == "marginrankingloss":
                return torch.nn.MarginRankingLoss()
            elif loss == "hingeembeddingloss":
                return torch.nn.HingeEmbeddingLoss()
            elif loss == "cosineembeddingloss":
                return torch.nn.CosineEmbeddingLoss()
            else:
                raise Exception("unknown loss specified: " + loss)
        else:
            raise Exception("no loss specified")
    @staticmethod
    def get_metrics(learner: Learner,
                    output_layer_activation_function: Optional[str] = None):
        logger = logging.getLogger(__name__)

        def thresholded_output_transform(output):
            y_hat, y = output
            y_hat = torch.round(y_hat)
            return y_hat, y

        def softmax_thresholded_output_transform(output):
            y_hat, y = output
            y_hat = torch.nn.functional.softmax(y_hat, dim=1)
            y_hat = torch.round(y_hat)
            return y_hat, y

        def sigmoid_thresholded_output_transform(output):
            y_hat, y = output
            y_hat = torch.sigmoid(y_hat)
            y_hat = torch.round(y_hat)
            return y_hat, y

        is_multilabel = \
            learner.definition.task == c.TaskType.MULTI_LABEL_CLASSIFICATION

        metrics_dict = dict()
        for metric in learner.metrics:
            metric = metric.lower().strip()
            if metric == "loss":
                metrics_dict[metric] = Loss(learner.criterion)
            elif metric == "accuracy":
                if is_multilabel:
                    if output_layer_activation_function is None:
                        metrics_dict[metric] = Accuracy(
                            thresholded_output_transform,
                            is_multilabel=is_multilabel)
                    elif output_layer_activation_function == "softmax":
                        metrics_dict[metric] = Accuracy(
                            softmax_thresholded_output_transform,
                            is_multilabel=is_multilabel)
                    elif output_layer_activation_function == "sigmoid":
                        metrics_dict[metric] = Accuracy(
                            sigmoid_thresholded_output_transform,
                            is_multilabel=is_multilabel)
                else:
                    metrics_dict[metric] = Accuracy(
                        is_multilabel=is_multilabel)
            else:
                logger.warning("unknown metric: %s", metric)
        return metrics_dict
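For multi-label accuracy, the output transforms above map logits to probabilities (via sigmoid or softmax, if requested) and round them to {0, 1} before ignite's Accuracy compares them to the targets. A small sketch of the sigmoid transform with made-up tensors:

import torch

logits = torch.tensor([[2.0, -3.0, 0.1], [-1.0, 1.5, -0.2]])
y = torch.tensor([[1, 0, 1], [0, 1, 0]])
y_hat = torch.round(torch.sigmoid(logits))
print(y_hat)  # tensor([[1., 0., 1.], [0., 1., 0.]])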