
# -*- coding: utf-8 -*-
"""
The executor submodule contains the PipelineExecutor, which runs an image processing pipeline.
"""

import gc
import logging
try:
    import tqdm
except ImportError:
    tqdm = None
from itertools import product
from collections import OrderedDict

from copy import deepcopy


from ..misc.processpool import cpu_count, SimpleProcessPool, InProcessFakePool, WrappedException
from .pipeline import PipelineEnvironment, Skip


class Collected(object):
    pass


class Every(object):
    pass


class NotDispatchedYet(object):
    pass

exception_debugging = False

singleton_class_mapper_local_cache = {}

def singleton_class_mapper(class_, what, args, kwargs):
    # Instantiate each class at most once per process (via __new__, without
    # running __init__ automatically), cache the instance, then call the
    # requested method on it.
    try:
        if class_ not in singleton_class_mapper_local_cache:
            singleton_class_mapper_local_cache[class_] = class_.__new__(class_)

        result = getattr(singleton_class_mapper_local_cache[class_], what)(*args, **kwargs)

        gc.collect()

        return result
    except Exception as _:
        raise

def get_progress_bar(n):
    if tqdm:
        return iter(tqdm.tqdm(range(n)))
    else:
        return iter(range(n))

# noinspection PyMethodMayBeStatic
class PipelineExecutor(object):
    wait = 0.01
    multiprocessing = True

    def __init__(self):
        if self.multiprocessing:
            if self.multiprocessing is True:
                self.multiprocessing = cpu_count()

    def set_workload(self, meta_tuple, item_counts, processing_order):
        self.meta_tuple = meta_tuple
        self.item_counts = item_counts
        self.processing_order = processing_order

    def initialize(self, pec, *args, **kwargs):
        self.log = logging.getLogger(__name__)

        self.pec = pec

        complete_args = (self.pec, '__init__', args, kwargs,)
        self.complete_args = complete_args

        singleton_class_mapper(*complete_args)

        if self.multiprocessing:
            self.pool = SimpleProcessPool(
                processes=self.multiprocessing,
                initializer=singleton_class_mapper,
                initargs=complete_args,
                future_timeout=30.0 * 60,  # thirty minute timeout, only works with the self-written pool
            )
        else:
            self.pool = InProcessFakePool()

    def set_progress_total(self, l):
        self.progress_indicator = get_progress_bar(l)

    def progress_tick(self):
        try:
            next(self.progress_indicator)
        except StopIteration:
            pass

    # noinspection PyUnusedLocal
    def in_cache(self, token):
        return False

    # noinspection PyUnusedLocal
    def get_cache(self, token):
        pass

    # noinspection PyUnusedLocal
    def set_cache(self, token, result):
        if result is None:
            return

    def skip_callback(self, op, skipped):
        self.log.info("Operation %r caused Skip for %r.", op, skipped)

    def run(self):
        meta_tuple = self.meta_tuple

        if getattr(self, 'cache', False) is False:
            self.cache = False

        sort_order = [index for index, _ in sorted(enumerate(self.processing_order), key=lambda p: p[0])]

        def prepare_steps(step, replace):
            return list(meta_tuple(*t) for t in sorted(product(*[
                self.item_counts[num] if value == replace else [value]
                for num, value in enumerate(step)
            ]), key=lambda t: [t[i] for i in sort_order]))

        todo = OrderedDict()
        reverse_todo = {}
        results = {}

        mapping = {}
        reverse_mapping = {}

        steps = singleton_class_mapper(self.pec, 'get_step_keys', (), {})

        for step in steps:
            order = prepare_steps(step, Every)
            reverse_todo.update({k: step for k in order})

            for k in order:
                todo[k] = NotDispatchedYet

            deps = {t: set(prepare_steps(t, Collected)) for t in order}
            mapping.update(deps)

            for key, value in deps.items():
                for k in value:
                    if k not in reverse_mapping:
                        reverse_mapping[k] = {key}
                    else:
                        reverse_mapping[k] |= {key}

        mapping_copy = deepcopy(mapping)

        def is_concrete(t):
            for n in t:
                if n is Collected or n is Every:
                    return False
            return True

        # initial_length = len(todo)

        self.set_progress_total(len(todo))

        check = OrderedDict()

        cache_originated = set()

        invalidated = set()

        concrete_counter, non_concrete_counter = 0, 0

        while len(todo) > 0 or len(check) > 0:
            for op in list(todo.keys()):
                result = None

                if op not in invalidated:
                    parameter_dict = {'meta': op}

                    if is_concrete(op):
                        # we are talking about a definite point, that is one that is not dependent on others
                        concrete_counter += 1
                        priority = 1 * concrete_counter
                    elif len(mapping[op]) != 0:
                        continue
                    else:
                        collected = OrderedDict()

                        for fetch in sorted(mapping_copy[op], key=lambda t: [t[i] for i in sort_order]):
                            collected[fetch] = results[fetch]

                        parameter_dict[PipelineEnvironment.KEY_COLLECTED] = collected

                        non_concrete_counter += 1
                        priority = -1 * non_concrete_counter

                    token = (reverse_todo[op], op,)

                    if self.in_cache(token):
                        cache_originated.add(op)
                        raise RuntimeError('TODO')
                        # TODO
                        # result = self.pool.advanced_apply(
                        #     command=singleton_class_mapper,
                        #     args=(self.__class__, 'get_cache', (token,), {},),
                        #     priority=priority
                        # )
                    else:
                        result = self.pool.advanced_apply(
                            singleton_class_mapper,
                            args=(self.pec, 'dispatch', (reverse_todo[op],), parameter_dict,),
                            priority=priority
                        )

                results[op] = result
                check[op] = True
                del todo[op]

            for op in list(check.keys()):
                result = results[op]

                modified = False

                if op in invalidated:
                    if getattr(result, 'fail', False) and callable(result.fail):
                        result.fail()
                    modified = True
                else:
                    if self.wait:
                        result.wait(self.wait)

                    if result.ready():
                        try:
                            result = result.get()

                            if op not in cache_originated:
                                token = (reverse_todo[op], op,)
                                # so far, solely accessing (write) the cache from
                                # one process should mitigate locking issues
                                self.set_cache(token, result)
                        except WrappedException as ee:
                            e = ee.exception

                            if type(e) == Skip:
                                old_invalid = invalidated.copy()

                                def _add_to_invalid(what):
                                    if what not in invalidated:
                                        invalidated.add(what)

                                        if what in mapping_copy:
                                            for item in mapping_copy[what]:
                                                _add_to_invalid(item)

                                _add_to_invalid(e.meta)

                                new_invalid = invalidated - old_invalid

                                self.skip_callback(op, new_invalid)
                            else:
                                if exception_debugging:
                                    raise

                                self.log.exception(
                                    "Exception occurred at op=%s: %s",
                                    repr(reverse_todo[op]) + ' ' + repr(op), ee.message
                                )

                            result = None

                        modified = True

                if modified:
                    results[op] = result

                    if op in reverse_mapping:
                        for affected in reverse_mapping[op]:
                            mapping[affected] -= {op}

                    del check[op]

                    self.progress_tick()

        self.progress_tick()

        self.close()

    def close(self):
        self.pool.close()