diff --git a/Mikado/parsers/__init__.py b/Mikado/parsers/__init__.py
index 558390319..a50330bc2 100644
--- a/Mikado/parsers/__init__.py
+++ b/Mikado/parsers/__init__.py
@@ -11,6 +11,7 @@
 import bz2
 from functools import partial
 import magic
+import multiprocessing as mp


 class HeaderError(Exception):
@@ -27,17 +28,21 @@ class Parser(metaclass=abc.ABCMeta):

     def __init__(self, handle):
         self.__closed = False
+        self.__from_queue = False
         if not isinstance(handle, io.IOBase):
-            if handle.endswith(".gz") or self.wizard.from_file(handle) == b"application/gzip":
-                opener = gzip.open
-            elif handle.endswith(".bz2") or self.wizard.from_file(handle) == b"application/x-bzip2":
-                opener = bz2.open
+            if isinstance(handle, mp.queues.Queue):
+                self.__from_queue = True
             else:
-                opener = partial(open, **{"buffering": 1})
-            try:
-                handle = opener(handle, "rt")
-            except FileNotFoundError:
-                raise FileNotFoundError("File not found: {0}".format(handle))
+                if handle.endswith(".gz") or self.wizard.from_file(handle) == b"application/gzip":
+                    opener = gzip.open
+                elif handle.endswith(".bz2") or self.wizard.from_file(handle) == b"application/x-bzip2":
+                    opener = bz2.open
+                else:
+                    opener = partial(open, **{"buffering": 1})
+                try:
+                    handle = opener(handle, "rt")
+                except FileNotFoundError:
+                    raise FileNotFoundError("File not found: {0}".format(handle))

         self._handle = handle
         self.closed = False
@@ -46,8 +51,20 @@ def __iter__(self):
         return self

     def __next__(self):
-        line = self._handle.readline()
-        return line
+
+        if self.__from_queue:
+            line = self._handle.get_nowait()
+            if isinstance(line, bytes):
+                line = line.decode()
+            if line in ("EXIT", b"EXIT"):
+                self.close()
+        else:
+            try:
+                line = self._handle.readline()
+            except StopIteration:
+                self.close()
+                raise StopIteration
+        return line

     def __enter__(self):
         if self.closed is True:
@@ -56,7 +73,10 @@ def __enter__(self):

     def __exit__(self, *args):
         _ = args
-        self._handle.close()
+        if self.__from_queue is False:
+            self._handle.close()
+        else:
+            self._handle.join()
         self.closed = True

     def close(self):
@@ -70,7 +90,10 @@ def name(self):
         """
         Return the filename.
         """
-        return self._handle.name
+        if self.__from_queue:
+            return ""
+        else:
+            return self._handle.name

     @property
     def closed(self):
@@ -103,6 +126,7 @@ def closed(self, *args):
 from . import blast_utils
 from . import bam_parser

+
 def to_gff(string, input_format=None):
     """
     Function to recognize the input file type (GFF or GTF).
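A minimal sketch of the queue-driven path added above; the input line, the short sleep, and instantiating the `Parser` base class directly are illustrative assumptions, not part of the patch:

# Illustrative only: feed a pre-filled multiprocessing queue to the base Parser.
import multiprocessing as mp
import time

from Mikado.parsers import Parser

records = mp.Queue()
records.put(b"Chr1\t100\t1000\tmodel_1\t0\t+\n")   # any raw text line, as bytes
records.put(b"EXIT")                                # sentinel recognised by Parser.__next__
time.sleep(0.1)                                     # give the queue's feeder thread time to flush

parser = Parser(records)     # passing a Queue sets the internal __from_queue flag
line = next(parser)          # returns the decoded text line
sentinel = next(parser)      # returns "EXIT"; the parser closes itself at this point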
diff --git a/Mikado/parsers/bed12.py b/Mikado/parsers/bed12.py
index 15a1b7bf5..8bbbcb1aa 100644
--- a/Mikado/parsers/bed12.py
+++ b/Mikado/parsers/bed12.py
@@ -25,6 +25,9 @@
 from Bio.Data.IUPACData import ambiguous_rna_letters as _ambiguous_rna_letters
 from Bio.Data import CodonTable
 import multiprocessing as mp
+import msgpack
+import logging
+import logging.handlers as logging_handlers


 backup_valid_letters = set(_ambiguous_dna_letters.upper() + _ambiguous_rna_letters.upper())
@@ -125,7 +128,9 @@ def _translate_str(sequence, table, stop_symbol="*", to_stop=False, cds=False, p
     """

     if cds and len(sequence) % 3 != 0:
-        raise CodonTable.TranslationError("Sequence length {0} is not a multiple of three".format(n))
+        raise CodonTable.TranslationError("Sequence length {0} is not a multiple of three".format(
+            len(sequence)
+        ))
     elif gap is not None and (not isinstance(gap, str) or len(gap) > 1):
         raise TypeError("Gap character should be a single character "
                         "string.")
@@ -719,6 +724,25 @@ def copy(self):

         return copy.deepcopy(self)

+    def as_simple_dict(self):
+
+        return {
+            "chrom": self.chrom,
+            "id": self.id,
+            "start": self.start,
+            "end": self.end,
+            "name": self.name,
+            "strand": self.strand,
+            "thick_start": self.thick_start,
+            "thick_end": self.thick_end,
+            "score": self.score,
+            "has_start_codon": self.has_start_codon,
+            "has_stop_codon": self.has_stop_codon,
+            "cds_len": self.cds_len,
+            "phase": self.phase,
+            "transcriptomic": self.transcriptomic,
+        }
+
     @property
     def strand(self):
         """
@@ -1138,7 +1162,6 @@ def __init__(self, handle,
                  max_regression=0,
                  is_gff=False,
                  coding=False,
-                 procs=None,
                  table=0):
         """
         Constructor method.
@@ -1167,11 +1190,13 @@ def __init__(self, handle,
                 fasta_index[numpy.random.choice(fasta_index.keys(), 1)],
                 Bio.SeqRecord.SeqRecord)
         elif fasta_index is not None:
-            if isinstance(fasta_index, str):
+            if isinstance(fasta_index, (str, bytes)):
+                if isinstance(fasta_index, bytes):
+                    fasta_index = fasta_index.decode()
                 assert os.path.exists(fasta_index)
                 fasta_index = pysam.FastaFile(fasta_index)
             else:
-                assert isinstance(fasta_index, pysam.FastaFile)
+                assert isinstance(fasta_index, pysam.FastaFile), type(fasta_index)

         self.fasta_index = fasta_index
         self.__closed = False
@@ -1267,3 +1292,118 @@ def coding(self, coding):
         if coding not in (False, True):
             raise ValueError(coding)
         self.__coding = coding
+
+
+class Bed12ParseWrapper(mp.Process):
+
+    def __init__(self,
+                 rec_queue=None,
+                 return_queue=None,
+                 log_queue=None, level="DEBUG",
+                 fasta_index=None,
+                 transcriptomic=False,
+                 max_regression=0,
+                 is_gff=False,
+                 coding=False,
+                 table=0):
+
+        """
+        :param rec_queue:
+        :type rec_queue: mp.Queue
+        :param return_queue:
+        :type return_queue: mp.Queue
+        :param kwargs:
+        """
+
+        super().__init__()
+        self.rec_queue = rec_queue
+        self.return_queue = return_queue
+        self.logging_queue = log_queue
+        self.handler = logging_handlers.QueueHandler(self.logging_queue)
+        self.logger = logging.getLogger(self.name)
+        self.logger.addHandler(self.handler)
+        self.logger.setLevel(level)
+        self.logger.propagate = False
+        self.transcriptomic = transcriptomic
+        self.__max_regression = 0
+        self._max_regression = max_regression
+        self.coding = coding
+
+        if isinstance(fasta_index, dict):
+            # check that this is a bona fide dictionary ...
+            assert isinstance(
+                fasta_index[numpy.random.choice(fasta_index.keys(), 1)],
+                Bio.SeqRecord.SeqRecord)
+        elif fasta_index is not None:
+            if isinstance(fasta_index, (str, bytes)):
+                if isinstance(fasta_index, bytes):
+                    fasta_index = fasta_index.decode()
+                assert os.path.exists(fasta_index)
+                fasta_index = pysam.FastaFile(fasta_index)
+            else:
+                assert isinstance(fasta_index, pysam.FastaFile), type(fasta_index)
+
+        self.fasta_index = fasta_index
+        self.__closed = False
+        self.header = False
+        self.__table = table
+        self._is_bed12 = (not is_gff)
+
+    def bed_next(self, line):
+        """
+
+        :return:
+        """
+
+        bed12 = BED12(line,
+                      fasta_index=self.fasta_index,
+                      transcriptomic=self.transcriptomic,
+                      max_regression=self._max_regression,
+                      coding=self.coding,
+                      table=self.__table)
+        return bed12
+
+    def gff_next(self, line):
+        """
+
+        :return:
+        """
+
+        line = GffLine(line)
+
+        if line.feature != "CDS":
+            return None
+        # Compatibility with BED12
+        bed12 = BED12(line,
+                      fasta_index=self.fasta_index,
+                      transcriptomic=self.transcriptomic,
+                      max_regression=self._max_regression,
+                      table=self.__table)
+        # raise NotImplementedError("Still working on this!")
+        return bed12
+
+    def run(self, *args, **kwargs):
+        while True:
+            line = self.rec_queue.get()
+            if line in ("EXIT", b"EXIT"):
+                self.rec_queue.put(b"EXIT")
+                break
+            try:
+                line = line.decode()
+            except AttributeError:
+                pass
+            if not self._is_bed12:
+                row = self.gff_next(line)
+            else:
+                row = self.bed_next(line)
+
+            if not row or row.header is True:
+                continue
+            if row.invalid is True:
+                self.logger.warn("Invalid entry, reason: %s\n%s",
+                                 row.invalid_reason,
+                                 row)
+                continue
+            self.return_queue.put(msgpack.dumps(row.as_simple_dict()))
+
+        # self.join()
diff --git a/Mikado/serializers/orf.py b/Mikado/serializers/orf.py
index 474d31682..ee572b2a2 100644
--- a/Mikado/serializers/orf.py
+++ b/Mikado/serializers/orf.py
@@ -19,7 +19,9 @@
 from ..utilities.log_utils import create_null_logger, check_logger
 import pandas as pd
 from ..exceptions import InvalidSerialization
+from ..parsers import Parser
 import multiprocessing as mp
+import msgpack


 # This is a serialization class, it must have a ton of attributes ...
@@ -113,7 +115,6 @@ def create_dict(bed12_object, query_id):
             "phase": bed12_object.phase
         }

-
     @classmethod
     def as_bed12_static(cls, state, query_name):
         """Class method to transform the mapper into a BED12 object.
@@ -197,10 +198,10 @@ def __init__(self,
         fasta_index = json_conf["serialise"]["files"]["transcripts"]
         self._max_regression = json_conf["serialise"]["max_regression"]
         self._table = json_conf["serialise"]["codon_table"]
-        # self.procs = json_conf["threads"]
-        # self.single_thread = json_conf["serialise"]["single_thread"]
-        # if self.single_thread:
-        #     self.procs = 1
+        self.procs = json_conf["threads"]
+        self.single_thread = json_conf["serialise"]["single_thread"]
+        if self.single_thread:
+            self.procs = 1

         if isinstance(fasta_index, str):
             assert os.path.exists(fasta_index)
@@ -214,19 +215,13 @@ def __init__(self,
         self.fasta_index = fasta_index

         if isinstance(handle, str):
-            self.is_bed12 = (handle.endswith("bed12") or handle.endswith("bed"))
+            self.is_bed12 = (".bed12" in handle or ".bed" in handle)
         else:
-            self.is_bed12 = (handle.name.endswith("bed12") or handle.name.endswith("bed"))
-
-        self.bed12_parser = bed12.Bed12Parser(handle,
-                                              fasta_index=fasta_index,
-                                              is_gff=(not self.is_bed12),
-                                              transcriptomic=True,
-                                              max_regression=self._max_regression,
-                                              table=self._table)
+            self.is_bed12 = (".bed12" in handle.name or ".bed" in handle.name)

         self.engine = connect(json_conf, logger)
+        self._handle = handle

         Session = sessionmaker(bind=self.engine, autocommit=False, autoflush=False, expire_on_commit=False)
         session = Session()
         # session.configure(bind=self.engine)
@@ -291,25 +286,15 @@ def load_fasta(self):
         self.logger.debug("Finished loading %d transcripts into query table", done)
         return

-    def serialize(self):
-        """
-        This method performs the parsing of the ORF file and the
-        loading into the SQL database.
-        """
+    def __serialize_single_thread(self):
+        self.bed12_parser = bed12.Bed12Parser(self._handle,
+                                              fasta_index=self.fasta_index,
+                                              is_gff=(not self.is_bed12),
+                                              transcriptomic=True,
+                                              max_regression=self._max_regression,
+                                              table=self._table)
         objects = []
-        # Dictionary to hold the data before bulk loading into the database
-
-        cache = dict()
-        for record in self.session.query(Query):
-            cache[record.query_name] = record.query_id
-
-        self.load_fasta()
-        # Reload
-
-        cache = pd.read_sql_table("query", self.engine, index_col="query_name", columns=["query_name", "query_id"])
-        cache = cache.to_dict()["query_id"]
-        initial_cache = (len(cache) > 0)

         done = 0
         not_found = set()
@@ -321,14 +306,14 @@ def serialize(self):
                                     row.invalid_reason,
                                     row)
                 continue
-            if row.id in cache:
-                current_query = cache[row.id]
-            elif not initial_cache:
+            if row.id in self.cache:
+                current_query = self.cache[row.id]
+            elif not self.initial_cache:
                 current_query = Query(row.id, row.end)
                 not_found.add(row.id)
                 self.session.add(current_query)
                 self.session.commit()
-                cache[current_query.query_name] = current_query.query_id
+                self.cache[current_query.query_name] = current_query.query_id
                 current_query = current_query.query_id
             else:
                 self.logger.critical(
@@ -365,6 +350,103 @@ def serialize(self):
         if orfs.shape[0] != done:
             raise ValueError("I should have serialised {} ORFs, but {} are present!".format(done, orfs.shape[0]))

+    def __serialize_multiple_threads(self):
+        """"""
+
+        send_queue = mp.JoinableQueue(-1)
+        return_queue = mp.JoinableQueue(-1)
+
+        parsers = [bed12.Bed12ParseWrapper(
+            rec_queue=send_queue,
+            return_queue=return_queue,
+            fasta_index=self.fasta_index.filename,
+            is_gff=(not self.is_bed12),
+            transcriptomic=True,
+            max_regression=self._max_regression,
+            table=self._table) for _ in range(self.procs)]
+
+        [_.start() for _ in parsers]
+
+        for line in open(self._handle):
+            send_queue.put(line.encode())
+        send_queue.put("EXIT")
+        not_found = set()
+        done = 0
+        objects = []
+        [parser.join() for parser in parsers]
+        return_queue.put("EXIT")
+        while True:
+            try:
+                object = return_queue.get_nowait()
+            except mp.queues.Empty:
+                break
+            if object in ("EXIT", b"EXIT"):
+                break
+            object = msgpack.loads(object, raw=False)
+
+            if object["id"] in self.cache:
+                current_query = self.cache[object["id"]]
+            elif not self.initial_cache:
+                current_query = Query(object["id"], object["end"])
+                not_found.add(object["id"])
+                self.session.add(current_query)
+                self.session.commit()
+                self.cache[current_query.query_name] = current_query.query_id
+                current_query = current_query.query_id
+            else:
+                self.logger.critical(
+                    "The provided ORFs do not match the transcripts provided and already present in the database.\
+Please check your input files.")
+                raise InvalidSerialization
+
+            object["query_id"] = current_query
+            objects.append(object)
+            if len(objects) >= self.maxobjects:
+                done += len(objects)
+                self.session.begin(subtransactions=True)
+                # self.session.bulk_save_objects(objects)
+                self.engine.execute(
+                    Orf.__table__.insert(),
+                    objects
+                )
+                self.session.commit()
+                self.logger.debug("Loaded %d ORFs into the database", done)
+                objects = []
+
+        done += len(objects)
+        # self.session.begin(subtransactions=True)
+        # self.session.bulk_save_objects(objects, update_changed_only=False)
+        if objects:
+            self.engine.execute(
+                Orf.__table__.insert(),
+                objects
+            )
+        return_queue.close()
+        send_queue.close()
+        self.session.commit()
+        self.session.close()
+        self.logger.info("Finished loading %d ORFs into the database", done)
+
+        orfs = pd.read_sql_table("orf", self.engine, index_col="query_id")
+        if orfs.shape[0] != done:
+            raise ValueError("I should have serialised {} ORFs, but {} are present!".format(done, orfs.shape[0]))
+
+    def serialize(self):
+        """
+        This method performs the parsing of the ORF file and the
+        loading into the SQL database.
+        """
+
+        self.load_fasta()
+        self.cache = pd.read_sql_table("query", self.engine, index_col="query_name", columns=["query_name", "query_id"])
+        self.cache = self.cache.to_dict()["query_id"]
+        self.initial_cache = (len(self.cache) > 0)
+
+        if self.procs == 1:
+            self.__serialize_single_thread()
+        else:
+            self.__serialize_multiple_threads()
+
     def __call__(self):
         """
         Alias for serialize
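Taken together, the worker/queue flow introduced above can be exercised along these lines. This is a rough sketch only: the FASTA and BED12 paths, the worker count, the max_regression value, and the throw-away logging queue are placeholders, and the query-cache bookkeeping done by __serialize_multiple_threads is omitted.

# Rough usage sketch of the new Bed12ParseWrapper workers; paths and parameters are placeholders.
import multiprocessing as mp
from queue import Empty

import msgpack

from Mikado.parsers.bed12 import Bed12ParseWrapper

rec_queue, return_queue = mp.JoinableQueue(-1), mp.JoinableQueue(-1)
workers = [Bed12ParseWrapper(rec_queue=rec_queue,
                             return_queue=return_queue,
                             log_queue=mp.Queue(),             # a real logging queue/listener would go here
                             fasta_index="transcripts.fasta",  # placeholder FASTA path
                             is_gff=False,                     # feeding BED12 lines, not GFF CDS lines
                             transcriptomic=True,
                             max_regression=0.2,               # placeholder value
                             table=0)
           for _ in range(4)]
[w.start() for w in workers]

with open("orfs.bed12") as handle:                             # placeholder ORF file
    for line in handle:
        rec_queue.put(line.encode())
rec_queue.put(b"EXIT")        # each worker re-queues the sentinel so its siblings see it too

[w.join() for w in workers]
return_queue.put(b"EXIT")
while True:
    try:
        packed = return_queue.get_nowait()
    except Empty:
        break
    if packed in (b"EXIT", "EXIT"):
        break
    bed = msgpack.loads(packed, raw=False)   # plain dict produced by BED12.as_simple_dict()
    # ... map bed["id"] to a query_id and load it into the orf table ...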