Source code for seqgra.evaluator.gradientbased.deepliftevaluator

"""DeepLIFT Evaluator
"""
import types
from typing import Optional

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import seqgra.constants as c
from seqgra.evaluator.gradientbased import AbstractGradientEvaluator
from seqgra.learner import Learner


[docs]class DeepLiftEvaluator(AbstractGradientEvaluator): """DeepLIFT evaluator for PyTorch models """ # TODO where to set reference? def __init__(self, learner: Learner, output_dir: str, importance_threshold: Optional[float] = None, baseline_type: str = "shuffled", silent: bool = False) -> None: super().__init__(c.EvaluatorID.DEEP_LIFT, "DeepLIFT", learner, output_dir, importance_threshold, silent=silent) self._prepare_reference() assert(baseline_type in ["neutral", "zeros", "shuffled"]) self.baseline_inp = None self.baseline_type = baseline_type self._override_backward()
[docs] def explain(self, x, y): self._reset_preference() self._baseline_forward(x) grad = self._backprop(x, y) return x.data * grad
def _prepare_reference(self): def init_refs(m): name = m.__class__.__name__ if name.find("ReLU") != -1: m.ref_inp_list = [] m.ref_out_list = [] def ref_forward(self, x): self.ref_inp_list.append(x.data.clone()) out = F.relu(x) self.ref_out_list.append(out.data.clone()) return out def ref_replace(m): name = m.__class__.__name__ if name.find("ReLU") != -1: m.forward = types.MethodType(ref_forward, m) self.learner.model.apply(init_refs) self.learner.model.apply(ref_replace) def _reset_preference(self): def reset_refs(m): name = m.__class__.__name__ if name.find("ReLU") != -1: m.ref_inp_list = [] m.ref_out_list = [] self.learner.model.apply(reset_refs) def _baseline_forward(self, inp): if self.baseline_inp is None: self.baseline_inp = inp.data.clone() if self.baseline_type == "neutral": self.baseline_inp.fill_(0.25) elif self.baseline_type == "zeros": self.baseline_inp.fill_(0.0) elif self.baseline_type == "shuffled": self.baseline_inp = self.baseline_inp[:, :, torch.randperm( self.baseline_inp.size()[2])] # TODO baseline_inp with shuffled k-mers?? self.baseline_inp = Variable(self.baseline_inp) # get ref _ = self.learner.model(self.baseline_inp) def _override_backward(self): def new_backward(self, grad_out): ref_inp, inp = self.ref_inp_list ref_out, out = self.ref_out_list delta_out = out - ref_out delta_in = inp - ref_inp g1 = (delta_in.abs() > 1e-5).float() * grad_out * \ delta_out / delta_in mask = ((ref_inp + inp) > 0).float() g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out return g1 + g2 def backward_replace(m): name = m.__class__.__name__ if name.find("ReLU") != -1: m.backward = types.MethodType(new_backward, m) self.learner.model.apply(backward_replace)