Source code for mycelyso.highlevel.pipeline

# -*- coding: utf-8 -*-
"""
The pipeline module contains the mycelyso-Pipeline, assembled from various functions.
"""

from os.path import basename, abspath
# noinspection PyUnresolvedReferences
import networkx as nx

from tunable import TunableManager

from ..tunables import CropWidth, CropHeight, BoxDetection, StoreImage, SkipBinarization
from .steps import *
from .. import __banner__
from .. import __version__
# noinspection PyUnresolvedReferences
from ..pilyso.application import App, PipelineExecutionContext, PipelineEnvironment, Every, Collected, Meta, Skip
from ..pilyso.imagestack import ImageStack
from ..pilyso.misc.h5writer import hdf5_output, hdf5_node_name, return_or_uncompress
from ..pilyso.pipeline.pipeline import NeatDict
from ..pilyso.steps import \
    image_source, pull_metadata_from_image, substract_start_frame, rescale_image_to_uint8, set_result, Delete, \
    box_detection, create_boxcrop_from_subtracted_image, calculate_image_sha256_hash, Compress, copy_calibration


[docs]class Mycelyso(App): """ The Mycelyso App, implementing a pilyso App. """
[docs] def options(self): return { 'name': "mycelyso", 'description': "", 'banner': __banner__, 'pipeline': MycelysoPipeline }
[docs] def arguments(self, argparser): argparser.add_argument('--meta', '--meta', dest='meta', default='') argparser.add_argument('--interactive', '--interactive', dest='interactive', default=False, action='store_true') argparser.add_argument('--output', '--output', dest='output', default='output.h5')
[docs] def handle_args(self): self.args.tunables = TunableManager.get_representation() if self.args.interactive: # if interactive, don't spawn workers self.args.processes = 0 self.run = self.interactive_run
[docs] def interactive_run(self): pipeline, fun, args, kwargs = self.pe.complete_args assert fun == '__init__' pipeline = pipeline(*args, **kwargs) import matplotlib.pyplot as plt from matplotlib.widgets import Slider fig, ax = plt.subplots() plt.subplots_adjust(left=0.25, bottom=0.25) fig.canvas.set_window_title("Image Viewer") slider_background = '#e7af12' slider_foreground = '#005b82' ax_mp = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=slider_background) ax_tp = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=slider_background) mp_max = max(self.positions) tp_max = max(self.timepoints) with warnings.catch_warnings(): warnings.simplefilter('ignore') multipoint = Slider(ax_mp, 'Multipoint', 0, mp_max, valinit=0, valfmt="%d", color=slider_foreground) timepoint = Slider(ax_tp, 'Timepoint', 0, tp_max, valinit=0, valfmt="%d", color=slider_foreground) env = {'show': True} def update(_): t = int(timepoint.val) position = int(multipoint.val) fig.canvas.set_window_title("Image Viewer - [BUSY]") plt.rcParams['image.cmap'] = 'gray' plt.sca(ax) plt.cla() plt.suptitle('[left/right] timepoint [up/down] multipoint [h] hide analysis') result = pipeline.dispatch(Meta(pos=Every, t=Every), meta=Meta(pos=position, t=t)) result = NeatDict(result) binary = return_or_uncompress(result.binary) skeleton = return_or_uncompress(result.skeleton) node_draw_parameters = dict(node_size=25, node_color='darkgray', linewidths=0, edge_color='darkgray') graph, pos = result.node_frame.get_networkx_graph(return_positions=True) nx.draw_networkx( graph, with_labels=False, pos=pos, **node_draw_parameters ) if env['show']: plt.imshow(binary, cmap='gray_r') else: plt.imshow(skeleton, cmap='gray_r') fig.canvas.set_window_title("Image Viewer - %s timepoint #%d %d/%d multipoint #%d %d/%d" % (self.args.input, t, 1 + t, 1 + tp_max, position, 1 + position, 1 + mp_max)) plt.draw() update(None) multipoint.on_changed(update) timepoint.on_changed(update) def key_press(event): if event.key == 'left': timepoint.set_val(max(1, int(timepoint.val) - 1)) elif event.key == 'right': timepoint.set_val(min(tp_max, int(timepoint.val) + 1)) elif event.key == 'ctrl+left': timepoint.set_val(max(1, int(timepoint.val) - 10)) elif event.key == 'ctrl+right': timepoint.set_val(min(tp_max, int(timepoint.val) + 10)) elif event.key == 'down': multipoint.set_val(max(1, int(multipoint.val) - 1)) elif event.key == 'up': multipoint.set_val(min(mp_max, int(multipoint.val) + 1)) elif event.key == 'h': env['show'] = not env['show'] update(None) elif event.key == 'q': raise SystemExit fig.canvas.mpl_connect('key_press_event', key_press) with warnings.catch_warnings(): warnings.simplefilter('ignore') fig.tight_layout() plt.show()
[docs]class MycelysoPipeline(PipelineExecutionContext): """ The MycelysoPipeline, defining the pipeline (with slight alterations based upon arguments passed via command line). """ def __init__(self, args): TunableManager.load(args.tunables) absolute_input = abspath(args.input) h5nodename = hdf5_node_name(absolute_input) self.pipeline_environment = PipelineEnvironment(ims=ImageStack(args.input)) per_image = self.add_stage(Meta(pos=Every, t=Every)) per_image |= set_result(tunables_hash=TunableManager.get_hash()) # read the image per_image |= image_source per_image |= calculate_image_sha256_hash per_image |= pull_metadata_from_image per_image |= lambda image, raw_image=None: image per_image |= lambda image, raw_unrotated_image=None: image per_image |= set_empty_crops # define what we want (per image) as results result_table = { '_plain': [ 'calibration', 'timepoint', 'input_height', 'input_width', 'area', 'covered_ratio', 'covered_area', 'graph_edge_length', 'graph_edge_count', 'graph_node_count', 'graph_junction_count', 'graph_endpoint_count', 'filename', 'metadata', 'shift_x', 'shift_y', 'crop_t', 'crop_b', 'crop_l', 'crop_r', 'image_sha256_hash', 'tunables_hash' ], 'graphml': 'data', # 'image': 'image', # 'raw_unrotated_image': 'image', # 'raw_image': 'image', 'skeleton': 'image', 'binary': 'image' } if StoreImage.value: result_table['image'] = 'image' per_image |= set_result( reference_timepoint=1, filename_complete=absolute_input, filename=basename(absolute_input), metadata=args.meta, result_table=result_table ) per_image |= substract_start_frame if BoxDetection.value: per_image |= box_detection per_image |= create_boxcrop_from_subtracted_image per_image |= rescale_image_to_uint8 per_image |= set_result(raw_unrotated_image=Delete, raw_image=Delete, subtracted_image=Delete) per_image |= lambda image: image[ CropHeight.value:-(CropHeight.value if CropHeight.value > 0 else 1), CropWidth.value:-(CropWidth.value if CropWidth.value > 0 else 1) ] per_image |= lambda crop_t, crop_b, crop_l, crop_r: ( crop_t + CropHeight.value, crop_b - CropHeight.value, crop_l + CropWidth.value, crop_r - CropWidth.value ) per_image |= skip_if_image_is_below_size(4, 4) # generate statistics of the image per_image |= image_statistics if not SkipBinarization.value: # binarize per_image |= binarize else: # noinspection PyUnusedLocal def _image_to_binary(image, binary=None): return image.astype(bool) per_image |= _image_to_binary # ... and cleanup per_image |= clean_up per_image |= remove_small_structures per_image |= remove_border_artifacts # generate statistics of the binarized image per_image |= quantify_binary per_image |= skeletonize # 'binary', 'skeleton' are kept! per_image |= convert_to_nodes if not StoreImage.value: per_image |= set_result(image=Delete) per_image |= set_result(pixel_frame=Delete) per_image |= graph_statistics per_image |= generate_graphml per_image |= set_result(binary=Compress, skeleton=Compress, graphml=Compress) # per position per_position = self.add_stage(Meta(pos=Every, t=Collected)) per_position |= copy_calibration per_position |= track_multipoint per_position |= generate_overall_graphml per_position |= individual_tracking per_position |= prepare_tracked_fragments per_position |= prepare_position_regressions per_position |= lambda meta, meta_pos=None: meta.pos per_position |= set_result( filename_complete=absolute_input, filename=basename(absolute_input), metadata=args.meta, tunables=TunableManager.get_serialization(), version=__version__, banner=__banner__, result_table={ '_plain': [ 'metadata', 'filename_complete', 'filename', 'meta_pos', 'calibration', '*_regression_*' ], 'tunables': 'data', 'version': 'data', 'banner': 'data', 'overall_graphml': 'data', 'track_table': 'table', 'track_table_aux_tables': 'table' } ) per_position |= hdf5_output(args.output, h5nodename) def black_hole(result): for k in list(result.keys()): del result[k] del result return {} per_position |= black_hole