# Source code for seqgra.learner.torch.torchdataset

"""MIT - CSAIL - Gifford Lab - seqgra

PyTorch DataSet class

@author: Konstantin Krismer
"""
from collections import deque
import random
from typing import Any, Deque, List, Tuple

import torch
import numpy as np

from seqgra.learner import Learner


class MultiClassDataSet(torch.utils.data.Dataset):
    """In-memory map-style dataset for multi-class classification.

    ``x`` is converted to a ``float32`` NumPy array. If ``y`` is given and
    has a boolean dtype (e.g. one-hot rows), it is collapsed to integer
    class indices by taking the arg-max along axis 1; labels of any other
    dtype are stored unchanged.

    Args:
        x: array-like of encoded examples, indexed along the first axis.
        y: optional array-like of labels aligned with ``x``.
    """

    def __init__(self, x, y=None):
        self.x = np.asarray(x).astype(np.float32)
        self.y = y

        if self.y is not None:
            if not isinstance(self.y, np.ndarray):
                self.y = np.array(self.y)

            # np.bool_ instead of the np.bool alias: the alias was removed
            # in NumPy 1.24 and raises AttributeError there.
            if self.y.dtype == np.bool_:
                self.y = np.argmax(self.y.astype(np.int64), axis=1)

    def __len__(self):
        """Return the number of examples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``x[idx]``, or ``(x[idx], y[idx])`` when labels exist."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.y is None:
            return self.x[idx]
        return self.x[idx], self.y[idx]
class MultiLabelDataSet(torch.utils.data.Dataset):
    """In-memory map-style dataset for multi-label classification.

    ``x`` is converted to a ``float32`` NumPy array. If ``y`` is given and
    has a boolean dtype, it is cast to ``float32`` with its shape unchanged;
    labels of any other dtype are stored unchanged.

    Args:
        x: array-like of encoded examples, indexed along the first axis.
        y: optional array-like of labels aligned with ``x``.
    """

    def __init__(self, x, y=None):
        self.x = np.asarray(x).astype(np.float32)
        self.y = y

        if self.y is not None:
            if not isinstance(self.y, np.ndarray):
                self.y = np.array(self.y)

            # np.bool_ instead of the np.bool alias: the alias was removed
            # in NumPy 1.24 and raises AttributeError there.
            if self.y.dtype == np.bool_:
                self.y = self.y.astype(np.float32)

    def __len__(self):
        """Return the number of examples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``x[idx]``, or ``(x[idx], y[idx])`` when labels exist."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.y is None:
            return self.x[idx]
        return self.x[idx], self.y[idx]
class IterableMultiClassDataSet(torch.utils.data.IterableDataset):
    """Stream a tab-separated example file as a multi-class dataset.

    The file is read in chunks of at most ``cache_size`` examples. Each
    chunk is optionally validated and shuffled, encoded through the learner
    and then served one example at a time. Encoded labels are collapsed to
    class indices by taking the arg-max along axis 1.

    Args:
        file_name: path of the text file. The first line is skipped as a
            header; every following line holds x, a tab, then y (or only
            x when ``contains_y`` is False).
        learner: object providing ``validate_data``, ``check_sequence``,
            ``check_labels``, ``encode_x`` and ``encode_y``.
        shuffle: shuffle examples within each chunk (never across chunks).
        contains_y: whether a label is yielded alongside each example.
        cache_size: maximum number of examples held in memory at once.

    NOTE(review): ``__iter__`` does not consult
    ``torch.utils.data.get_worker_info``, so every DataLoader worker would
    presumably iterate the complete file - confirm before using
    ``num_workers > 0``.
    """

    def __init__(self, file_name: str, learner: Learner,
                 shuffle: bool = False, contains_y: bool = True,
                 cache_size: int = 10000):
        self.file_name: str = file_name
        self.learner: Learner = learner
        self.shuffle: bool = shuffle
        self.contains_y: bool = contains_y
        self.cache_size: int = cache_size
        self.x_cache = None
        self.y_cache = None
        self.cache_index: int = cache_size

    def __iter__(self):
        """Yield ``(x, y)`` pairs, or bare ``x`` when ``contains_y`` is False."""
        # Start every pass with an empty cache so that examples left over
        # from a previous, abandoned pass are never served again.
        self.x_cache = None
        self.y_cache = None
        self.cache_index = self.cache_size

        with open(self.file_name, "r") as f:
            # skip header; the default keeps an empty file from raising
            # StopIteration inside the generator (a RuntimeError, PEP 479)
            next(f, None)

            x, y = self._get_next_example(f)
            while x is not None:
                if self.contains_y:
                    yield x, y
                else:
                    yield x
                x, y = self._get_next_example(f)

    def _get_next_example(self, file_handle) -> Tuple[Any, Any]:
        """Return the next ``(x, y)`` example, or ``(None, None)`` at EOF.

        ``y`` is ``None`` when ``contains_y`` is False. A new chunk is read
        from ``file_handle`` whenever the cache is exhausted.

        Raises:
            ValueError: if a line does not have the expected column count.
        """
        if self._cache_exhausted():
            self._load_next_chunk(file_handle)

        if self._cache_exhausted():
            return (None, None)

        idx = self.cache_index
        self.cache_index += 1
        if self.contains_y:
            return (self.x_cache[idx, ...], self.y_cache[idx])
        return (self.x_cache[idx, ...], None)

    def _cache_exhausted(self) -> bool:
        """Return True when no unread example is left in the cache."""
        return (self.x_cache is None
                or self.cache_index >= self.x_cache.shape[0])

    def _load_next_chunk(self, file_handle) -> None:
        """Read, validate, shuffle and encode up to ``cache_size`` lines.

        Leaves the cache untouched when no further line is available.
        """
        x_vec: List[str] = []
        y_vec: List[str] = []

        # Read a line only while the chunk still has room: reading ahead
        # and then leaving the loop would silently drop that line.
        while len(x_vec) < self.cache_size:
            line: str = file_handle.readline()
            if not line:
                break

            cells: List[str] = line.split("\t")
            if len(cells) == 2 or (len(cells) == 1 and not self.contains_y):
                x_vec.append(cells[0].strip())
                if self.contains_y:
                    y_vec.append(cells[1].strip())
            else:
                raise ValueError("invalid example: " + line)

        if not x_vec:
            return

        # validate data
        if self.learner.validate_data:
            self.learner.check_sequence(x_vec)
            if self.contains_y:
                self.learner.check_labels(y_vec)

        # shuffle within the chunk, keeping x and y aligned
        if self.shuffle:
            if self.contains_y:
                paired = list(zip(x_vec, y_vec))
                random.shuffle(paired)
                x_vec = [x for x, _ in paired]
                y_vec = [y for _, y in paired]
            else:
                random.shuffle(x_vec)

        # process chunk in memory
        encoded_x_vec = self.learner.encode_x(x_vec)
        if not isinstance(encoded_x_vec, np.ndarray):
            encoded_x_vec = np.array(encoded_x_vec)
        self.x_cache = encoded_x_vec.astype(np.float32)

        if self.contains_y:
            encoded_y_vec = self.learner.encode_y(y_vec)
            if not isinstance(encoded_y_vec, np.ndarray):
                encoded_y_vec = np.array(encoded_y_vec)
            # collapse each encoded label row to its class index
            self.y_cache = np.argmax(encoded_y_vec.astype(np.int64), axis=1)

        self.cache_index = 0
class IterableMultiLabelDataSet(torch.utils.data.IterableDataset):
    """Stream a tab-separated example file as a multi-label dataset.

    The file is read in chunks of at most ``cache_size`` examples. Each
    chunk is optionally validated and shuffled, encoded through the learner
    and then served one example at a time. Encoded labels are cast to
    ``float32`` with their shape unchanged.

    Args:
        file_name: path of the text file. The first line is skipped as a
            header; every following line holds x, a tab, then y (or only
            x when ``contains_y`` is False).
        learner: object providing ``validate_data``, ``check_sequence``,
            ``check_labels``, ``encode_x`` and ``encode_y``.
        shuffle: shuffle examples within each chunk (never across chunks).
        contains_y: whether a label is yielded alongside each example.
        cache_size: maximum number of examples held in memory at once.

    NOTE(review): ``__iter__`` does not consult
    ``torch.utils.data.get_worker_info``, so every DataLoader worker would
    presumably iterate the complete file - confirm before using
    ``num_workers > 0``.
    """

    def __init__(self, file_name: str, learner: Learner,
                 shuffle: bool = False, contains_y: bool = True,
                 cache_size: int = 10000):
        self.file_name: str = file_name
        self.learner: Learner = learner
        self.shuffle: bool = shuffle
        self.contains_y: bool = contains_y
        self.cache_size: int = cache_size
        self.x_cache = None
        self.y_cache = None
        self.cache_index: int = cache_size

    def __iter__(self):
        """Yield ``(x, y)`` pairs, or bare ``x`` when ``contains_y`` is False."""
        # Start every pass with an empty cache so that examples left over
        # from a previous, abandoned pass are never served again.
        self.x_cache = None
        self.y_cache = None
        self.cache_index = self.cache_size

        with open(self.file_name, "r") as f:
            # skip header; the default keeps an empty file from raising
            # StopIteration inside the generator (a RuntimeError, PEP 479)
            next(f, None)

            x, y = self._get_next_example(f)
            while x is not None:
                if self.contains_y:
                    yield x, y
                else:
                    yield x
                x, y = self._get_next_example(f)

    def _get_next_example(self, file_handle) -> Tuple[Any, Any]:
        """Return the next ``(x, y)`` example, or ``(None, None)`` at EOF.

        ``y`` is ``None`` when ``contains_y`` is False. A new chunk is read
        from ``file_handle`` whenever the cache is exhausted.

        Raises:
            ValueError: if a line does not have the expected column count.
        """
        if self._cache_exhausted():
            self._load_next_chunk(file_handle)

        if self._cache_exhausted():
            return (None, None)

        idx = self.cache_index
        self.cache_index += 1
        if self.contains_y:
            return (self.x_cache[idx, ...], self.y_cache[idx])
        return (self.x_cache[idx, ...], None)

    def _cache_exhausted(self) -> bool:
        """Return True when no unread example is left in the cache."""
        return (self.x_cache is None
                or self.cache_index >= self.x_cache.shape[0])

    def _load_next_chunk(self, file_handle) -> None:
        """Read, validate, shuffle and encode up to ``cache_size`` lines.

        Leaves the cache untouched when no further line is available.
        """
        x_vec: List[str] = []
        y_vec: List[str] = []

        # Read a line only while the chunk still has room: reading ahead
        # and then leaving the loop would silently drop that line.
        while len(x_vec) < self.cache_size:
            line: str = file_handle.readline()
            if not line:
                break

            cells: List[str] = line.split("\t")
            if len(cells) == 2 or (len(cells) == 1 and not self.contains_y):
                x_vec.append(cells[0].strip())
                if self.contains_y:
                    y_vec.append(cells[1].strip())
            else:
                raise ValueError("invalid example: " + line)

        if not x_vec:
            return

        # validate data
        if self.learner.validate_data:
            self.learner.check_sequence(x_vec)
            if self.contains_y:
                self.learner.check_labels(y_vec)

        # shuffle within the chunk, keeping x and y aligned
        if self.shuffle:
            if self.contains_y:
                paired = list(zip(x_vec, y_vec))
                random.shuffle(paired)
                x_vec = [x for x, _ in paired]
                y_vec = [y for _, y in paired]
            else:
                random.shuffle(x_vec)

        # process chunk in memory
        encoded_x_vec = self.learner.encode_x(x_vec)
        if not isinstance(encoded_x_vec, np.ndarray):
            encoded_x_vec = np.array(encoded_x_vec)
        self.x_cache = encoded_x_vec.astype(np.float32)

        if self.contains_y:
            encoded_y_vec = self.learner.encode_y(y_vec)
            if not isinstance(encoded_y_vec, np.ndarray):
                encoded_y_vec = np.array(encoded_y_vec)
            self.y_cache = encoded_y_vec.astype(np.float32)

        self.cache_index = 0