"""Build Spektral-style graph tensors from ForSys / Surface Evolver workflows."""
__all__ = ['se_output_to_graph', 'skeleton_to_graph']
import random
from collections import defaultdict
from itertools import combinations
from logging import getLogger
from pathlib import Path
from time import perf_counter
from typing import Any
import forsys as fs
import numpy as np
from scipy.sparse import dok_matrix
from scipy.spatial.distance import euclidean as euclidean_dist
from .utils.data_utils import ConnectedNodes, resample_vertices
from .utils.plotting import plot_with_force
# Anything the public loaders accept as a filesystem path.
StrPath = str | Path
# Module-level logger, named after this module.
log = getLogger(__name__)
# Names of the node-feature columns, in column order; the length of this list fixes the
# width of the node-feature matrix built in `_forsys_frame_to_graph`.
NODE_FEATURES = ['arc_length', 'chord_length', 'cell1_area', 'cell2_area', 'cell1_per', 'cell2_per']
def _get_neighboring_vertices(vertex: fs.vertex.Vertex,
                              lattice: fs.surface_evolver.SurfaceEvolver
                              ) -> list[fs.vertex.Vertex]:
    """Return the vertices found at the other end of every edge owned by ``vertex``.

    Parameters
    ----------
    vertex
        Vertex whose neighbors are requested; its ``ownEdges`` ids are looked up in ``lattice.edges``.
    lattice
        Lattice holding the ``edges`` mapping (edge id -> edge with ``v1`` / ``v2`` endpoints).

    Returns
    -------
    list
        One neighbor per entry of ``vertex.ownEdges``, in the same order.
    """
    incident_edges = (lattice.edges[edge_id] for edge_id in vertex.ownEdges)
    # for each incident edge pick the endpoint that is not `vertex` itself
    return [edge.v2 if edge.v1.id == vertex.id else edge.v1 for edge in incident_edges]
def _jitter_vertices(lattice: fs.surface_evolver.SurfaceEvolver,
                     scale: float,
                     skip_big_edge_vertices: bool = True,
                     random_seed: int | None = None
                     ) -> None:
    """Add Gaussian noise, in place, to the coordinates of the lattice vertices.

    The noise standard deviation of each vertex is ``scale`` times the distance to its closest
    neighbor. Vertices are processed sequentially, so the distance to a neighbor that was already
    visited is measured against its jittered position.

    Parameters
    ----------
    lattice
        Lattice whose ``vertices`` are modified in place.
    scale
        Multiplier applied to the closest-neighbor distance to obtain the noise stdev.
    skip_big_edge_vertices
        If ``True``, only vertices connected to exactly 2 edges are jittered.
    random_seed
        Seed for the random number generator (``None`` lets ``random.Random`` seed itself).
    """
    rng = random.Random(random_seed)
    for v in lattice.vertices.values():
        if skip_big_edge_vertices and (len(v.ownEdges) != 2):
            continue  # only jitter vertices connected by 2 edges
        nvs = _get_neighboring_vertices(v, lattice)
        if not nvs:
            # isolated vertex: there is no neighbor distance to scale the noise by (and `min` of an
            # empty sequence would raise), so leave it untouched
            continue
        # calculate additive noise stdev as the scale factor multiplied
        # by the distance between the vertex and its closest neighbor
        stdev = scale * min(euclidean_dist([v.x, v.y], [nv.x, nv.y]) for nv in nvs)
        # add gaussian noise to each coordinate
        v.x += rng.normalvariate(0, stdev)
        v.y += rng.normalvariate(0, stdev)
def _forsys_frame_to_graph(frame: fs.frames.Frame,
                           include_targets: bool,
                           edge_n_vertices: int,
                           apply_savgol_filter: bool,
                           include_forsys_predictions: bool,
                           forsys_solve_method: str,
                           render_plots: bool,
                           plots_dir: StrPath,
                           return_timers: bool,
                           tag_cell_interfaces: list | None,
                           raise_if_gt_is_zero: bool,
                           load_time: float,
                           plots_prefix: str,
                           return_frame: bool,
                           verbose: bool
                           ) -> dict[str, Any]:
    """Convert a single ForSys frame into graph arrays (adjacency, node features, edge features).

    Every internal big-edge of ``frame`` becomes a graph node, and two nodes are linked when their
    big-edges share a tip vertex. Nodes that end up without any link are dropped from all outputs.

    Parameters
    ----------
    frame
        ForSys frame to convert. Its edges are filtered in place when ``apply_savgol_filter`` is set,
        and the ForSys solver is run on it when ``include_forsys_predictions`` is set.
    include_targets
        Read ``gt`` from every internal big-edge and return it (mean-normalized) under ``'y'``.
    edge_n_vertices
        Number of points stored per big-edge; big-edges with a different number of vertices are
        resampled with ``resample_vertices``.
    apply_savgol_filter
        Call ``frame.filter_edges(method='SG')`` before extracting anything.
    include_forsys_predictions
        Solve the frame with ForSys and return its (mean-normalized) tensions under ``'forsys_preds'``.
    forsys_solve_method
        ``method`` argument forwarded to ``ForSys.solve_stress``.
    render_plots
        Render plots with ``plot_with_force`` (only when targets and/or ForSys predictions are included).
    plots_dir
        Directory where the plots are written; it is created if missing.
    return_timers
        Add ``'total_time'``, ``'load_time'`` and (with ForSys predictions) ``'forsys_pred_time'``.
    tag_cell_interfaces
        Optional list of cell-id pairs; big-edges whose adjacent cells match a pair are reported
        under ``'tagged_nodes'``.
    raise_if_gt_is_zero
        Assert that every ground-truth tension is strictly positive.
    load_time
        Time already spent loading the data; added to ``'total_time'``.
    plots_prefix
        Filename prefix of the rendered plots.
    return_frame
        Return ``frame`` itself under ``'frame'``.
    verbose
        Log the major processing steps.

    Returns
    -------
    dict
        ``'a'``: CSR boolean adjacency matrix over the connected nodes.
        ``'x'``: node features (columns as in ``NODE_FEATURES``) scaled to ``[0, 1]``.
        ``'e'``: edge features of shape ``(n_edges, 2 * (edge_n_vertices - 1), 2)`` scaled to
        ``[-1, 1]`` and ordered like the adjacency entries in row-major order.
        Optional keys: ``'removed_nodes'``, ``'tagged_nodes'``, ``'y'``, ``'forsys_preds'``,
        ``'total_time'``, ``'load_time'``, ``'forsys_pred_time'`` and ``'frame'``.

    Raises
    ------
    AssertionError
        When any of the consistency checks on the frame or on the produced arrays fails.
    ValueError
        When two adjacent big-edges do not share the coordinates of any end point.
    """
    st = perf_counter()
    # optionally apply filter to remove noise
    if apply_savgol_filter:
        frame.filter_edges(method='SG')
    n_big_edges = len(frame.internal_big_edges)
    # predict tensions with forsys using this frame only
    forsys_pred_time = None
    if include_forsys_predictions:
        if verbose:
            log.info('Predicting tensions with forsys...')
        ft = perf_counter()
        forsys_engine = fs.ForSys({0: frame})
        forsys_engine.build_force_matrix(when=0, angle_limit=np.inf)
        forsys_engine.solve_stress(when=0, allow_negatives=False, method=forsys_solve_method)
        forsys_pred_time = perf_counter() - ft
    if verbose:
        log.info('Extracting features...')
    # big-edge vertex tracker to help us build the adjacency matrix later
    adj_tracker = defaultdict(set)
    # container for the big-edge points
    node_points = np.empty((n_big_edges, edge_n_vertices, 2), dtype=np.float32)
    # node features
    node_features = np.empty((n_big_edges, len(NODE_FEATURES)), dtype=np.float32)
    # ground-truths and forsys predictions
    ground_truth = np.empty(n_big_edges, dtype=np.float32) if include_targets else None
    forsys_preds = np.empty(n_big_edges, dtype=np.float32) if include_forsys_predictions else None
    # boolean array to tag specific big-edges
    # (builtin `bool` instead of the `np.bool` alias, which is missing in NumPy 1.24-1.26)
    tagged_nodes = np.empty(n_big_edges, dtype=bool) if tag_cell_interfaces else None
    to_tag = {tuple(sorted(cell_pair)) for cell_pair in tag_cell_interfaces} if tag_cell_interfaces else None
    bigi = 0  # index of the current internal big-edge (external ones are not counted)
    for bigedge in frame.big_edges.values():
        if bigedge.external:
            continue
        if include_targets:
            assert (bigedge.gt > 0) or (not raise_if_gt_is_zero), f'Big-edge #{bigi} gt tension is {bigedge.gt:.4f}'
            ground_truth[bigi] = bigedge.gt
        if include_forsys_predictions:
            tension = bigedge.tension
            if tension < 0:
                log.warning(f'ForSys predicted tension for big-edge #{bigi} is negative: {tension:.4f}.')
            forsys_preds[bigi] = tension
        smedges = bigedge.edges  # list of small-edges ids
        n_smedges = len(smedges)
        # initialize the array of vertex coordinates
        n_points = n_smedges + 1
        points = np.empty((n_points, 2), dtype=np.float32)
        c1a = c2a = c1p = c2p = None
        is_tagged = False
        for smalli, edge_id in enumerate(smedges):
            # get edge
            edge = frame.edges[edge_id]
            # get the coordinates of the segment
            v1 = [edge.v1.x, edge.v1.y]
            v2 = [edge.v2.x, edge.v2.y]
            # store vertex #1 coordinates
            points[smalli] = v1
            # in the first one, store vertex_id and small edge coords
            if smalli == 0:
                assert len(edge.v1.ownEdges) > 2, (
                    f'Expecting tip of Big-edge #{bigi} to be connected to at least 3 '
                    f'edges, got {edge.v1.ownEdges}')
                adj_tracker[bigi].add(edge.v1.id)
                # get adjacent cell features from the second vertex of the first small-edge (not a triple junction)
                if n_smedges >= 2:
                    cells = edge.v2.ownCells
                else:  # unless it has only 1 small edge and in that case we get cells that are common to both
                    cells = list(set(edge.v1.ownCells).intersection(edge.v2.ownCells))
                assert len(cells) <= 2, f'Expecting a maximum of 2 ownCells for vertex {edge.v2}, got {cells}.'
                c1 = frame.cells[cells[0]]
                c1a = abs(c1.get_area())
                c1p = c1.get_perimeter()
                if len(cells) > 1:
                    c2 = frame.cells[cells[1]]
                    c2a = abs(c2.get_area())
                    c2p = c2.get_perimeter()
                else:  # if we know of only one cell, replicate its data for the second one
                    c2a, c2p = c1a, c1p
                # check if we need to tag this big-edge
                if (to_tag is not None) and (tuple(sorted(cells)) in to_tag):
                    is_tagged = True
            # in the last one also store vertex #2 coordinates, vertex_id and small edge coords
            # (not an `elif`: a big-edge made of a single small-edge is both first and last)
            if smalli == (n_smedges - 1):
                assert len(edge.v2.ownEdges) > 2, (
                    f'Expecting tip of Big-edge #{bigi} to be connected to at least 3 '
                    f'edges, got {edge.v2.ownEdges}')
                points[-1] = v2  # store vertex #2 coordinates
                adj_tracker[bigi].add(edge.v2.id)
        if n_points != edge_n_vertices:
            # use spline interpolation to create intermediate points
            points = resample_vertices(points, target_length=edge_n_vertices, spline_order=2)
            operation = 'upsampled' if (edge_n_vertices > n_points) else 'downsampled'
            log.info(f'Big-edge #{bigi} points were {operation} from {n_points} '
                     f'to {edge_n_vertices} using 2nd order spline interpolation.')
        # calculate the length as if the edge was a straight line a.k.a. distance between junctions
        chordlen = euclidean_dist(points[0], points[-1])
        # calculate the arc-length as the sum of all small-edge lengths
        arclen = sum(euclidean_dist(points[i], points[i + 1]) for i in range(len(points) - 1))
        # populate arrays
        nf = [arclen, chordlen, c1a, c2a, c1p, c2p]
        assert all((v and v > 0) for v in nf), f'Invalid values in node features: {nf}'
        node_features[bigi] = np.array(nf, dtype=np.float32)
        node_points[bigi] = points
        if tagged_nodes is not None:
            tagged_nodes[bigi] = is_tagged
        bigi += 1
    # TODO: for sure there are more efficient ways to exclude disconnected nodes. This works for now.
    adjacency = []
    connected_set = set()
    for i, j in combinations(range(n_big_edges), 2):
        intersect = adj_tracker[i].intersection(adj_tracker[j])
        if intersect:
            assert len(intersect) == 1, f'A pair of big edges ({i} and {j}) share {len(intersect)} vertices (WTF?)'
            adjacency.append((i, j))
            connected_set.update([i, j])
    # indices of all connected nodes
    connected = sorted(connected_set)
    idx_old_to_new = {n: i for i, n in enumerate(connected)}  # indexer for node_features, ground_truth and forsys_preds
    # initialize adjacency matrix
    adj_mat = dok_matrix((len(connected),) * 2, dtype=bool)
    # edge features (store node indices as they get added to the adj. mat. so we can use them to order edge feats later)
    edge_index_buffer, edge_features_buffer = [], []
    # populate the adjacency matrix and edge features arrays
    for i, j in adjacency:
        # make sure that edge features for i->j are always i> oooo[X]oooo >j, where [X] is the junction point (removed)
        if np.array_equal(node_points[i][0], node_points[j][0]):  # Xoooo Xoooo
            edge_points_ij = np.concatenate((node_points[i][::-1][:-1], node_points[j][1:]))
            joint_coords = node_points[i][0]
        elif np.array_equal(node_points[i][-1], node_points[j][0]):  # ooooX Xoooo
            edge_points_ij = np.concatenate((node_points[i][:-1], node_points[j][1:]))
            joint_coords = node_points[i][-1]
        elif np.array_equal(node_points[i][-1], node_points[j][-1]):  # ooooX ooooX
            edge_points_ij = np.concatenate((node_points[i][:-1], node_points[j][::-1][1:]))
            joint_coords = node_points[i][-1]
        elif np.array_equal(node_points[i][0], node_points[j][-1]):  # Xoooo ooooX
            edge_points_ij = np.concatenate((node_points[i][::-1][:-1], node_points[j][::-1][1:]))
            joint_coords = node_points[i][0]
        else:
            raise ValueError(f'WTF?\n{node_points[i]}\n{node_points[j]}')
        # move points so the joint ends up at the origin of coordinates.
        # the junction point is a constant feature (always [0, 0]), that's why we excluded it from the tensor beforehand
        # (`edge_points_ij` is a fresh array from `np.concatenate`, so `node_points` is not modified here)
        edge_points_ij -= joint_coords
        # the features of the edge j->i are the sequence of points in the features of the edge i->j in reverse order
        edge_points_ji = edge_points_ij[::-1]
        # add edge features to the matrix
        edge_features_buffer.extend([edge_points_ij, edge_points_ji])
        # populate adj matrix and keep track of the rows/columns in order to sort edge features later
        i_, j_ = idx_old_to_new[i], idx_old_to_new[j]
        adj_mat[i_, j_] = True
        adj_mat[j_, i_] = True
        edge_index_buffer.extend([[i_, j_], [j_, i_]])
    if verbose:
        log.info('Normalizing data and running additional validations...')
    # filter out disconnected nodes from the original node_features matrix
    node_features = node_features[connected]
    if tagged_nodes is not None:
        tagged_nodes = tagged_nodes[connected]
    # order edge features in the way spektral expects them (as they appear in the adj. mat. sorted in row-major order)
    edge_index = np.array(edge_index_buffer)
    sort_idx = np.lexsort(np.flipud(edge_index.T))  # row-major order
    edge_features = np.array(edge_features_buffer, dtype=np.float32)[sort_idx]
    # rescale coordinates: divide by the maximum edge arc-length so all coords get rescaled to the range [-1, 1]
    # (the maximum must be read before the length columns themselves are rescaled below)
    max_length = node_features[:, 0].max()
    edge_features /= max_length
    # also scale lengths, cell areas and cell perimeters to the range [0, 1]
    node_features[:, 0:2] /= max_length  # arc-length and chord-length
    node_features[:, 2:4] /= node_features[:, 2:4].max()  # cell areas
    node_features[:, 4:6] /= node_features[:, 4:6].max()  # cell perimeters
    # convert adjacency matrix to CSR
    adj_mat = adj_mat.tocsr()
    # make sure that all nodes in the graph have at least one neighbor
    CN = ConnectedNodes(adj_mat)
    CN.assert_all_connected()
    assert edge_features.shape[0] == adj_mat.nnz, 'Expecting number of entries in the adj matrix to match n_edges'
    assert node_features.shape[0] == CN.get_count(), 'Expecting n_nodes to be equal to n_connected_nodes'
    assert (np.min(node_features) >= 0) and (np.max(node_features) <= 1), 'Node feats not in range [0, 1]'
    assert (np.min(edge_features) >= -1) and (np.max(edge_features) <= 1), 'Edge feats not in range [-1, 1]'
    # initialize the outputs
    out = {'a': adj_mat, 'x': node_features, 'e': edge_features}
    # return disconnected nodes if any
    n_removed = n_big_edges - len(connected_set)
    if n_removed > 0:
        log.warning(f'{n_removed} nodes were excluded for being disconnected from the graph.')
        out['removed_nodes'] = np.array(sorted(set(range(n_big_edges)) - connected_set), dtype=int)
    # return tagged nodes if any
    if (tagged_nodes is not None) and tagged_nodes.any():
        out['tagged_nodes'] = np.flatnonzero(tagged_nodes)
    # get a dataframe with predicted tensions and ground truths from forsys.frames.Frame object
    fs_frame_tensions = frame.get_tensions()
    if include_targets:
        assert np.allclose(ground_truth, fs_frame_tensions['gt'].values), 'ForSys ground-truths do not match ours.'
        # filter out tensions of disconnected nodes (we cannot predict them)
        targets = ground_truth[connected]
        assert targets.shape[0] == node_features.shape[0], 'Expecting n_targets to be equal to n_nodes'
        # rescale target vector to have mean=1
        targets_mean = targets.mean()
        if targets_mean > 0:  # targets mean can only be zero if all tensions are 0 (ground-truth not available)
            targets /= targets_mean
        out['y'] = targets
    if include_forsys_predictions:
        assert np.allclose(forsys_preds, fs_frame_tensions['stress'].values), 'ForSys tensions do not match ours.'
        # filter out predictions of disconnected nodes
        forsys_preds_f = forsys_preds[connected]
        assert forsys_preds_f.shape[0] == node_features.shape[0], 'Expecting n_forsys_preds to be equal to n_nodes'
        assert np.all(forsys_preds_f >= 0), f'Negative values in filtered forsys preds: {forsys_preds_f}'
        if np.any(forsys_preds_f == 0):
            log.warning('Found zeros in ForSys predictions which will be ignored in mean normalization. '
                        'These values should be masked-out when calculating metrics to obtain accurate results.')
        # re-normalize forsys predictions after removing disconnected nodes (make sure to ignore zeros in this step).
        # without any positive value the mean is undefined and the division would turn every prediction into NaN,
        # so in that case the all-zero predictions are returned unchanged (same policy as for all-zero targets)
        positive_preds = forsys_preds_f[forsys_preds_f > 0]
        if positive_preds.size > 0:
            forsys_preds_f = forsys_preds_f / positive_preds.mean()
        out['forsys_preds'] = forsys_preds_f
    # end the timer, the rest is just for debugging purposes
    total_time = perf_counter() - st + load_time
    if return_timers:
        out.update({'total_time': total_time, 'load_time': load_time})
        if include_forsys_predictions:
            out['forsys_pred_time'] = forsys_pred_time
    # TODO: move this plotting step to its own function. Had to modify func in forsys (careful with future versions)
    if render_plots and (include_targets or include_forsys_predictions):
        if verbose:
            log.info('Plotting...')
        plots_base = Path(plots_dir)
        # make sure the output directory exists (no-op when it already does)
        plots_base.mkdir(parents=True, exist_ok=True)
        if include_targets:
            # plot with ground-truth tensions
            plot_with_force(frame, filename=str(plots_base / f'{plots_prefix}_gt'), force_to_plot='gt')
        if include_forsys_predictions:
            # plot with tensions predicted by forsys
            plot_with_force(frame, filename=str(plots_base / f'{plots_prefix}_forsys'), force_to_plot='tension')
    if return_frame:
        out['frame'] = frame
    return out
def se_output_to_graph(src_file: StrPath,
                       *,
                       include_targets: bool = True,
                       edge_n_vertices: int = 9,
                       apply_savgol_filter: bool = False,
                       include_forsys_predictions: bool = True,
                       forsys_solve_method: str = 'default',
                       render_plots: bool = False,
                       plots_dir: StrPath = './debug',
                       return_timers: bool = True,
                       tag_cell_interfaces: list | None = None,
                       jitter_kwargs: dict | None = None,
                       raise_if_gt_is_zero: bool = True,
                       debug_plots_prefix: str | None = None,
                       return_frame: bool = False,
                       verbose: bool = True
                       ) -> dict[str, Any]:
    """Load a Surface Evolver dump, build a ForSys frame, and return graph arrays.

    Parameters
    ----------
    src_file
        Path to the Surface Evolver output file.
    include_targets
        Passed as ``gt`` to the ForSys frame; when ``True`` the ground-truth tensions are also
        returned under ``'y'``.
    edge_n_vertices
        Number of points stored per big-edge; big-edges with a different number of vertices are resampled.
    apply_savgol_filter
        If ``True``, the frame edges are filtered (``method='SG'``) before features are extracted.
    include_forsys_predictions
        Whether to solve the frame with ForSys and include its tensions in the output dict.
    forsys_solve_method
        Solver method forwarded to ForSys.
    render_plots
        If ``True``, write debug plots under ``plots_dir``.
    plots_dir
        Directory for optional debug figures.
    return_timers
        Include timing fields in the returned dictionary.
    tag_cell_interfaces
        Optional list of cell-id pairs whose shared big-edges are reported as tagged nodes.
    jitter_kwargs
        Optional keyword arguments for ``_jitter_vertices`` (``scale`` is required; ``skip_big_edge_vertices``
        and ``random_seed`` are optional). Vertices are jittered before the frame is built.
    raise_if_gt_is_zero
        If ``True``, assert that ground-truth tensions are strictly positive when targets are included.
    debug_plots_prefix
        Filename prefix for plots; defaults to the stem of ``src_file``.
    return_frame
        If ``True``, include the constructed ``frame`` object under key ``'frame'``.
    verbose
        Enable logging of major steps.

    Returns
    -------
    dict
        Keys ``'a'``, ``'x'`` and ``'e'``, plus optional ``'y'``, ``'forsys_preds'``, timing keys and
        other optional entries depending on the flags.
    """
    log.debug('Loading data in forsys...')
    src_path = Path(src_file)
    # start main timer
    st = perf_counter()
    # load data from SE output in forsys
    lattice = fs.surface_evolver.SurfaceEvolver(str(src_path))
    if jitter_kwargs:
        _jitter_vertices(lattice, **jitter_kwargs)
    frame = fs.frames.Frame(0, lattice.vertices, lattice.edges, lattice.cells, time=0, gt=include_targets)
    load_time = perf_counter() - st
    # define debug plot filenames prefixes from the name of the input file
    # (`stem` already equals the full name when the path has no suffix)
    if debug_plots_prefix is None:
        debug_plots_prefix = src_path.stem
    # extract features and build the graph
    return _forsys_frame_to_graph(frame=frame,
                                  include_targets=include_targets,
                                  edge_n_vertices=edge_n_vertices,
                                  apply_savgol_filter=apply_savgol_filter,
                                  include_forsys_predictions=include_forsys_predictions,
                                  forsys_solve_method=forsys_solve_method,
                                  render_plots=render_plots,
                                  plots_dir=plots_dir,
                                  return_timers=return_timers,
                                  tag_cell_interfaces=tag_cell_interfaces,
                                  raise_if_gt_is_zero=raise_if_gt_is_zero,
                                  load_time=load_time,
                                  plots_prefix=debug_plots_prefix,
                                  return_frame=return_frame,
                                  verbose=verbose)
def skeleton_to_graph(src_file: StrPath,
                      gt_file: StrPath | None = None,
                      *,
                      mirror_y: bool = False,
                      edge_n_vertices: int = 9,
                      fixed_ne: int | None = None,
                      apply_savgol_filter: bool = False,
                      include_forsys_predictions: bool = True,
                      forsys_solve_method: str = 'default',
                      render_plots: bool = False,
                      plots_dir: StrPath = './debug',
                      return_timers: bool = True,
                      tag_cell_interfaces: list | None = None,
                      raise_if_gt_is_zero: bool = True,
                      debug_plots_prefix: str | None = None,
                      return_frame: bool = False,
                      verbose: bool = True
                      ) -> dict[str, Any]:
    """Load a skeleton ``.tif``, build a mesh in ForSys, optionally attach myosin GT, and return graph tensors.

    Parameters
    ----------
    src_file
        Path to the skeleton image (must have a ``.tif`` extension).
    gt_file
        Optional myosin image (must have a ``.tif`` extension) read with ``fs.myosin.read_myosin``
        to assign ground-truth tensions to the big-edges. Targets are returned only when it is given.
    mirror_y
        Forwarded to ``fs.skeleton.Skeleton``.
    edge_n_vertices
        Number of points stored per big-edge; also sets ``ne = edge_n_vertices - 1`` for mesh
        generation unless ``fixed_ne`` is given.
    fixed_ne
        If set (and non-zero), used as ``ne`` for ``generate_mesh`` instead of the derived value.
    apply_savgol_filter
        If ``True``, the frame edges are filtered (``method='SG'``) before features are extracted.
    include_forsys_predictions
        Whether to include ForSys baseline predictions in the output.
    forsys_solve_method
        Solver method forwarded to ForSys.
    render_plots
        Write debug plots when ``True``.
    plots_dir
        Output directory for debug plots.
    return_timers
        Attach timing metadata to the result dict.
    tag_cell_interfaces
        Optional list of cell-id pairs whose shared big-edges are reported as tagged nodes.
    raise_if_gt_is_zero
        Assert strictly positive ground-truth tensions when ``gt_file`` is provided.
    debug_plots_prefix
        Plot filename prefix; defaults to the stem of ``src_file``.
    return_frame
        Include the ForSys ``frame`` under ``'frame'`` when ``True``.
    verbose
        Enable progress logging.

    Returns
    -------
    dict
        Graph tensors and optional targets / predictions, same style as :func:`se_output_to_graph`.

    Raises
    ------
    AssertionError
        When ``src_file`` or ``gt_file`` does not have a ``.tif`` extension.
    """
    if verbose:
        log.info('Loading data in forsys...')
    src_path = Path(src_file)
    gt_path = Path(gt_file) if gt_file is not None else None
    include_targets = gt_path is not None
    # start main timer
    st = perf_counter()
    # validate both inputs up-front so that a wrong ground-truth path is reported before any mesh is built
    assert src_path.suffix.lower() == '.tif', 'Expecting a .tif file to generate the mesh.'
    if gt_path is not None:
        assert gt_path.suffix.lower() == '.tif', 'Expecting a .tif file to estimate the ground-truth from pixel intensities.'
    # load data from skeleton .tif file in forsys
    ne = fixed_ne or (edge_n_vertices - 1)
    skeleton = fs.skeleton.Skeleton(str(src_path), mirror_y=mirror_y)
    vertices, edges, cells = skeleton.create_lattice()
    vertices, edges, cells, _ = fs.virtual_edges.generate_mesh(vertices, edges, cells, ne=ne)
    frame = fs.frames.Frame(0, vertices, edges, cells, time=0, gt=False)
    if gt_path is not None:
        gt_tensions = fs.myosin.read_myosin(frame, str(gt_path), layers=2)  # layers value used in ForSys paper
        frame.assign_gt_tensions_to_big_edges(list(gt_tensions.values()))
    load_time = perf_counter() - st
    # define debug plot filenames prefixes from the name of the input file
    # (`stem` already equals the full name when the path has no suffix)
    if debug_plots_prefix is None:
        debug_plots_prefix = src_path.stem
    # extract features and build the graph
    return _forsys_frame_to_graph(frame=frame,
                                  include_targets=include_targets,
                                  edge_n_vertices=edge_n_vertices,
                                  apply_savgol_filter=apply_savgol_filter,
                                  include_forsys_predictions=include_forsys_predictions,
                                  forsys_solve_method=forsys_solve_method,
                                  render_plots=render_plots,
                                  plots_dir=plots_dir,
                                  return_timers=return_timers,
                                  tag_cell_interfaces=tag_cell_interfaces,
                                  raise_if_gt_is_zero=raise_if_gt_is_zero,
                                  load_time=load_time,
                                  plots_prefix=debug_plots_prefix,
                                  return_frame=return_frame,
                                  verbose=verbose)