# -*- coding: utf-8 -*-

"""
The nodeframe module contains the NodeFrame class, a representation of the graph of one time-lapse frame.
"""

from itertools import chain

import numpy as np
from networkx import find_cycle, NetworkXNoCycle, from_scipy_sparse_matrix
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import shortest_path, connected_components
from scipy.spatial.ckdtree import cKDTree as KDTree

from ..tunables import NodeEndpointMergeRadius, NodeJunctionMergeRadius, NodeLookupRadius, \
    NodeLookupCutoffRadius, NodeTrackingJunctionShiftRadius, NodeTrackingEndpointShiftRadius

from ..misc.util import calculate_length, clean_by_radius

class NodeFrame(object):
    """
    Node frame is a representation of an image stack frame on the graph/node level,
    it populates its values from a PixelFrame passed.

    All nodes share one index space: endpoints come first (indices ``endpoint_shift`` up to
    ``junction_shift - 1``), followed by junctions (``junction_shift`` onwards).
    ``data`` holds the coordinates of every node in that order, and ``adjacency`` is the
    symmetric (nodes x nodes) matrix whose entries are the lengths of the pathlets
    connecting two nodes.
    """

    def __init__(self, pf):
        """
        Initializes the NodeFrame

        :param pf: PixelFrame the NodeFrame corresponds to
        """
        # initializing vars
        self.timepoint = None
        self.calibration = None

        self.junction_shift = None
        self.endpoint_shift = None

        self.data = None

        self.endpoint_tree = None
        self.junction_tree = None

        self.endpoint_tree_data = None
        self.junction_tree_data = None

        self.adjacency = None

        self.every_endpoint = None
        self.every_junction = None

        self.predecessors = None
        self.shortest_paths = None
        self.shortest_paths_num = None

        self.connected_components_count = None
        self.connected_components = None

        self._cycles = None

        self.self_to_successor = None
        self.successor_to_self = None
        self.self_to_successor_alternatives = None
        # /initializing vars

        # copy this information, so we can set pf to None for serialization
        self.timepoint = pf.timepoint
        self.calibration = pf.calibration

        self.prepare_graph(pf)

    def prepare_graph(self, pf):
        """
        Prepares the graph from the data stored in the PixelFrame pf.

        Nearby endpoints/junctions are merged, KD-trees are built for node lookup,
        and every pathlet of ``pf`` whose two ends can be attached to nodes becomes a
        symmetric adjacency entry holding the pathlet length. Afterwards, unconnected
        nodes are removed and derived data (shortest paths, components) is generated.

        :param pf: PixelFrame
        :return:
        """
        endpoint_tree_data = clean_by_radius(pf.endpoints, NodeEndpointMergeRadius.value / self.calibration)
        junction_tree_data = clean_by_radius(pf.junctions, NodeJunctionMergeRadius.value / self.calibration)

        e_length = len(endpoint_tree_data)
        j_length = len(junction_tree_data)

        total_length = e_length + j_length

        data = np.r_[endpoint_tree_data, junction_tree_data]

        endpoint_tree_data = data[:e_length]
        junction_tree_data = data[e_length:]

        endpoint_tree = KDTree(endpoint_tree_data) if e_length > 0 else None
        junction_tree = KDTree(junction_tree_data) if j_length > 0 else None

        # while ends and junctions need to remain different,
        # they are put in the same graph / adjacency matrix
        # so, first come end nodes, then junction nodes
        # => shifts
        junction_shift = e_length
        endpoint_shift = 0

        adjacency = lil_matrix((total_length, total_length), dtype=float)

        # little bit of nomenclature:
        # a pathlet (pixel graph so to say) is a path of on the image
        # its begin is the 'left' l_ side, its end is the 'right' r_ side
        # (not using begin / end not to confuse end with endpoint ...)

        distance_threshold = NodeLookupRadius.value / self.calibration
        cutoff_radius = NodeLookupCutoffRadius.value / self.calibration

        def _side_is_endpoint(side):
            # A pathlet end counts as an endpoint if a (merged) endpoint lies close by;
            # otherwise fall back to the pixel-level endpoint map of the PixelFrame.
            # Without any endpoints there is no tree to ask, so only the map decides.
            if endpoint_tree is not None:
                test_distance, _ = endpoint_tree.query(side, k=1)
                if test_distance < distance_threshold:
                    return True
            return pf.endpoints_map[side[0], side[1]]

        for pathlet in pf.pathlets:
            pathlet_length = calculate_length(pathlet)

            l_side = pathlet[0]
            r_side = pathlet[-1]

            l_is_end = _side_is_endpoint(l_side)
            r_is_end = _side_is_endpoint(r_side)

            l_tree = endpoint_tree if l_is_end else junction_tree
            r_tree = endpoint_tree if r_is_end else junction_tree

            if l_tree is None or r_tree is None:
                # no node of the required kind exists, the pathlet cannot be attached
                continue

            l_distance, l_index = l_tree.query(l_side, k=1)
            r_distance, r_index = r_tree.query(r_side, k=1)

            if l_distance > cutoff_radius or r_distance > cutoff_radius:
                # probably does not happen
                continue

            adjacency_left_index = l_index + (endpoint_shift if l_is_end else junction_shift)
            adjacency_right_index = r_index + (endpoint_shift if r_is_end else junction_shift)

            adjacency[adjacency_left_index, adjacency_right_index] = pathlet_length
            adjacency[adjacency_right_index, adjacency_left_index] = pathlet_length

        self.junction_shift = junction_shift
        self.endpoint_shift = endpoint_shift

        self.data = data

        self.endpoint_tree = endpoint_tree
        self.junction_tree = junction_tree

        self.endpoint_tree_data = endpoint_tree_data
        self.junction_tree_data = junction_tree_data

        self.adjacency = adjacency

        self.every_endpoint = range(self.endpoint_shift, self.junction_shift)
        self.every_junction = range(self.junction_shift, self.junction_shift + len(self.junction_tree_data))

        cleanup_graph_after_creation = True

        if cleanup_graph_after_creation:
            self.cleanup_adjacency()

        self.adjacency = self.adjacency.tocsr()

        self.generate_derived_data()

    def cleanup_adjacency(self):
        """
        Cleans up the adjacency matrix after alterations on the node level have been performed.

        Nodes without any edge (empty row and empty column) are dropped. The node data,
        the endpoint/junction KD-trees, the index shifts, the endpoint/junction ranges and
        the adjacency matrix (as a lil_matrix) are rebuilt so the remaining nodes are
        numbered contiguously, endpoints first, then junctions.

        :return:
        """
        non_empty_mask = (self.adjacency.getnnz(axis=0) + self.adjacency.getnnz(axis=1)) > 0
        empty_indices = np.flatnonzero(~non_empty_mask)

        # if this ever becomes multi threaded, we should lock the trees now
        # endpoint_tree, junction_tree = None, None

        e_length = int(non_empty_mask[:self.junction_shift].sum())
        j_length = int(non_empty_mask[self.junction_shift:].sum())

        total_length = e_length + j_length

        self.junction_shift = e_length
        self.endpoint_shift = 0

        self.data = self.data[non_empty_mask]

        self.endpoint_tree_data = self.data[:e_length]
        self.junction_tree_data = self.data[e_length:]

        self.endpoint_tree = KDTree(self.endpoint_tree_data) if e_length > 0 else None
        self.junction_tree = KDTree(self.junction_tree_data) if j_length > 0 else None

        coo = self.adjacency.tocoo()

        # Every stored entry belongs to a kept node, so its new index is the old index minus
        # the number of removed (empty) nodes preceding it.
        new_rows = coo.row - np.searchsorted(empty_indices, coo.row)
        new_cols = coo.col - np.searchsorted(empty_indices, coo.col)

        new_adjacency = lil_matrix((total_length, total_length), dtype=self.adjacency.dtype)
        for n, m, value in zip(new_rows, new_cols, coo.data):
            new_adjacency[n, m] = value

        self.adjacency = new_adjacency

        self.every_endpoint = range(self.endpoint_shift, self.junction_shift)
        self.every_junction = range(self.junction_shift, self.junction_shift + len(self.junction_tree_data))

    def generate_derived_data(self):
        """
        Generates derived data from the current adjacency matrix.

        Derived data are the weighted shortest path distances (plus the predecessor matrix
        used by get_path), the unweighted shortest path (hop count) distances, as well as
        the connected components.

        :return:
        """
        self.shortest_paths, self.predecessors = shortest_path(self.adjacency, return_predecessors=True)
        self.shortest_paths_num = shortest_path(self.adjacency, unweighted=True)

        self.connected_components_count, self.connected_components = connected_components(self.adjacency)

    @property
    def cycles(self):
        """
        Detects whether a cycle exists in the graph (the result is cached).

        :return: True if the networkx representation of the graph contains a cycle
        """
        if self._cycles is None:
            g = self.get_networkx_graph()
            try:
                find_cycle(g)
                self._cycles = True
            except NetworkXNoCycle:
                self._cycles = False

        return self._cycles

    def get_path(self, start_node, end_node):
        """
        Walks from start_node to end_node in the graph and returns the list of nodes (including both).

        :param start_node: index of the first node
        :param end_node: index of the last node
        :return: list of node indices from start_node to end_node
        :raises IndexError: if end_node cannot be reached from start_node
        """
        node = end_node
        path = [node]

        while node != start_node:
            node = self.predecessors[start_node, node]
            # scipy marks "no predecessor" with a negative sentinel (-9999); walking it
            # would index arbitrary rows (or fail), so stop explicitly.
            if node < 0:
                raise IndexError("no path from node %d to node %d" % (start_node, end_node))
            path.append(node)

        return path[::-1]

    def is_endpoint(self, i):
        """
        Returns whether node i is an endpoint.

        :param i: node index
        :return: True if i lies in the endpoint index range
        """
        return i in self.every_endpoint

    def is_junction(self, i):
        """
        Returns whether node i is a junction.

        :param i: node index
        :return: True if i lies in the junction index range
        """
        return i in self.every_junction

    def get_connected_nodes(self, some_node):
        """
        Get all nodes which are (somehow) connected to node some_node.

        :param some_node: node index
        :return: array of the indices of all nodes in the same connected component (including some_node)
        """
        label = self.connected_components[some_node]
        return np.where(self.connected_components == label)[0]

    def track(self, successor):
        """
        Tracks nodes on this frame to nodes on a successor frame.

        Each endpoint (junction) of this frame is matched to the nearest endpoint (junction)
        of the successor frame within a search radius proportional to the elapsed time.
        The time difference is divided by 3600; presumably timepoints are in seconds and
        the radii are per hour (TODO confirm).

        Results are stored on self, all indices being node indices of the respective frame:
        ``self_to_successor[i]`` is the matched successor node of node i (-1 if none),
        ``successor_to_self[j]`` is a node of this frame matched to successor node j (-1 if none),
        ``self_to_successor_alternatives[i]`` lists all candidate successor nodes of node i,
        nearest first.

        :param successor: NodeFrame of the following time point
        :return:
        """
        delta_t = (successor.timepoint - self.timepoint) / (60.0 * 60.0)

        junction_shift_radius = (NodeTrackingJunctionShiftRadius.value / self.calibration) * delta_t
        endpoint_shift_radius = (NodeTrackingEndpointShiftRadius.value / self.calibration) * delta_t

        self_len = len(self.data)
        successor_len = len(successor.data)

        self_to_successor = np.full(self_len, -1, dtype=int)
        successor_to_self = np.full(successor_len, -1, dtype=int)

        self_to_successor_alternatives = [[] for _ in range(self_len)]

        def _candidates(own_tree, other_tree, count, radius):
            # One candidate list per node of this frame; if either tree is missing, the
            # lists are empty but still present, so enumeration stays aligned with node indices.
            if own_tree is not None and other_tree is not None:
                return own_tree.query_ball_tree(other_tree, radius)
            return [[] for _ in range(count)]

        endpoint_mapping = _candidates(self.endpoint_tree, successor.endpoint_tree,
                                       self.junction_shift - self.endpoint_shift, endpoint_shift_radius)
        junction_mapping = _candidates(self.junction_tree, successor.junction_tree,
                                       len(self.junction_tree_data), junction_shift_radius)

        for self_hit, candidates in enumerate(chain(endpoint_mapping, junction_mapping)):
            if len(candidates) == 0:
                # no match: keep -1 / [] defaults
                continue

            # tree indices are relative to the successor's endpoint / junction tree,
            # convert them to successor node indices
            if self_hit >= self.junction_shift:
                candidate_shift = successor.junction_shift
            else:
                candidate_shift = successor.endpoint_shift

            candidates = np.asarray(candidates, dtype=int) + candidate_shift

            search_point = self.data[self_hit]
            hit_points = successor.data[candidates]

            distances = np.sqrt(((hit_points - search_point) ** 2).sum(axis=1))

            # stable sort: among equal distances the tree's candidate order is kept,
            # so the first entry is the same choice np.argmin would make
            ordered = [int(c) for c in candidates[np.argsort(distances, kind='stable')]]
            best = ordered[0]

            self_to_successor_alternatives[self_hit] = ordered
            self_to_successor[self_hit] = best
            successor_to_self[best] = self_hit

        self.self_to_successor = self_to_successor  # this is mainly used
        self.successor_to_self = successor_to_self

        self.self_to_successor_alternatives = self_to_successor_alternatives

    def get_networkx_graph(self, with_z=0, return_positions=False):
        """
        Convert the adjacency matrix based internal graph representation to a networkx graph representation.

        Positions are additionally set based upon the pixel coordinate based positions of the nodes
        (node attribute 'x' is the second coordinate, 'y' the first).

        :param with_z: Whether z values should be set based upon the timepoint the nodes appear on
                       (if > 0, z = timepoint * with_z)
        :param return_positions: Whether positions should be returned jointly with the graph
        :return: the graph, or a tuple (graph, positions dict) if return_positions is set
        """
        g = from_scipy_sparse_matrix(self.adjacency)

        # networkx < 2.4 exposes node attribute dicts as ``g.node``, newer versions only as ``g.nodes``
        node_attributes = g.node if hasattr(g, 'node') else g.nodes

        positions = {}

        for n, pos in enumerate(self.data):
            x, y = float(pos[1]), float(pos[0])

            node_attributes[n]['x'] = x
            node_attributes[n]['y'] = y

            if with_z > 0:
                z = float(self.timepoint * with_z)
                node_attributes[n]['z'] = z
                positions[n] = (x, y, z)
            else:
                positions[n] = (x, y)

        if return_positions:
            return g, positions

        return g