"""
Matplotlib based plotting
"""
import logging
import pathlib
import shutil
import subprocess
import tempfile
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ipywidgets import interactive
from matplotlib import colormaps
from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
from matplotlib.patches import Arc, FancyArrow, PathPatch
from matplotlib.path import Path
from ..config.draw import sheet_spec
from ..utils.utils import get_sub_eptm, spec_updater
# Logger for this module, named after its dotted import path.
log = logging.getLogger(__name__)
# Default pair of vertex coordinate columns used for 2D views.
COORDS = ["x", "y"]
def browse_history(
    history,
    coords=None,
    start=None,
    stop=None,
    size=None,
    draw_func=None,
    margin=5,
    **draw_kwds,
):
    """Returns a browser widget with 2D plots of the epithelium

    Parameters
    ----------
    history : a :class:`tyssue.History` object
    coords : list of two str, optional
        the vertex coordinate columns used for the axis limits; defaults to
        the first two entries of ``history.sheet.coords``. When given
        explicitly, they are also passed to the drawing function.
    start, stop, size :
        passed to ``history.slice`` to select the browsed time points
    draw_func : callable, optional
        drawing function taking a sheet as first argument and an ``ax``
        keyword, returning a ``fig, ax`` pair. Defaults to
        :func:`quick_edge_draw` when ``draw_kwds["mode"]`` is "quick" or
        absent, to :func:`sheet_view` otherwise.
    margin : float, default 5
        margin around the first frame's vertices, in percent of the largest
        coordinate span
    **draw_kwds :
        passed to the drawing function. When the default drawing function is
        used, a "mode" entry only selects it and is not forwarded.

    Returns
    -------
    widget : an ``ipywidgets.interactive`` widget with a frame slider
    """
    if draw_func is None:
        # "mode" only chooses the default drawer; forwarding it would reach
        # ``ax.plot`` in quick_edge_draw as an unknown keyword.
        mode = draw_kwds.pop("mode", None)
        draw_func = quick_edge_draw if mode in ("quick", None) else sheet_view

    times = history.slice(start, stop, size)
    n_frames = times.size

    if coords is None:
        coords = list(history.sheet.coords[:2])
    else:
        # honour explicitly requested coordinates in the drawing as well
        coords = list(coords)
        draw_kwds["coords"] = coords
    x, y = coords

    sheet0 = history.retrieve(0)
    bounds = sheet0.vert_df[coords].describe().loc[["min", "max"]]
    delta = (bounds.loc["max"] - bounds.loc["min"]).max()
    pad = delta * margin / 100
    xlim = bounds.loc["min", x] - pad, bounds.loc["max", x] + pad
    ylim = bounds.loc["min", y] - pad, bounds.loc["max", y] + pad

    def set_frame(i=0):
        t = times[i]
        sheet = history.retrieve(t)
        fig = plt.figure(2)
        ax = fig.subplots()
        fig, ax = draw_func(sheet, ax=ax, **draw_kwds)
        ax.set(xlim=xlim, ylim=ylim)
        plt.show()

    widget = interactive(set_frame, i=(0, n_frames - 1))
    widget.layout.height = "500px"
    return widget
def create_gif(
    history,
    output,
    num_frames=None,
    interval=None,
    draw_func=None,
    margin=5,
    **draw_kwds,
):
    """Creates an animated gif of the recorded history.

    You need imagemagick (the ``convert`` command) on your system for this
    function to work.

    Parameters
    ----------
    history : a :class:`tyssue.History` object
    output : path to the output gif file
    num_frames : int, the number of frames in the gif
    interval : tuple, define begin and end frame of the gif
    draw_func : a drawing function
         this function must take a `sheet` object as first argument
         and return a `fig, ax` pair. Defaults to sheet_view
    margin : int, the graph margins in percents, default 5
         if margin is negative (e.g. -1), let the draw function decide
         the axis limits
    **draw_kwds are passed to the drawing function

    Raises
    ------
    FileNotFoundError
        if the ``convert`` executable cannot be found
    subprocess.CalledProcessError
        if ``convert`` exits with a non-zero status

    Frames whose drawing raises are skipped (and reported on stdout).
    The intermediate png files are always removed.
    """
    if draw_func is None:
        draw_func = sheet_view

    x, y = coords = draw_kwds.get("coords", history.sheet.coords[:2])
    sheet0 = history.retrieve(0)
    bounds = sheet0.vert_df[coords].describe().loc[["min", "max"]]
    delta = (bounds.loc["max"] - bounds.loc["min"]).max()
    # Decide on the raw percentage: once scaled, a negative margin becomes
    # -0.0 (i.e. >= 0) for a degenerate zero-span first frame.
    set_limits = margin >= 0
    pad = delta * margin / 100
    xlim = bounds.loc["min", x] - pad, bounds.loc["max", x] + pad
    ylim = bounds.loc["min", y] - pad, bounds.loc["max", y] + pad

    if interval is None:
        start, stop = None, None
    else:
        start, stop = interval[0], interval[1]

    # The context manager removes the frames even if drawing, saving or
    # the conversion raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        graph_dir = pathlib.Path(tmp_dir)
        for i, (_, sheet) in enumerate(history.browse(start, stop, num_frames)):
            try:
                fig, ax = draw_func(sheet, **draw_kwds)
            except Exception as e:
                print(f"Droped frame {i}")
                print(e)
                continue
            if set_limits and isinstance(ax, plt.Axes):
                ax.set(xlim=xlim, ylim=ylim)
            fig.savefig(graph_dir / f"movie_{i:04d}.png")
            plt.close(fig)

        try:
            # imagemagick expands the wildcard itself (no shell involved);
            # check=True surfaces conversion failures instead of ignoring them
            subprocess.run(
                ["convert", (graph_dir / "movie_*.png").as_posix(), output],
                check=True,
            )
        except Exception:
            print(
                "Converting didn't work, make sure imagemagick is available on your system"
            )
            raise
def sheet_view(sheet, coords=COORDS, ax=None, cbar_axis=None, **draw_specs_kw):
    """Base view function, parametrizable
    through draw_specs

    Parameters
    ----------
    sheet : the epithelium to draw
    coords : list of two str, the vertex coordinate columns to plot
    ax : matplotlib Axes, optional
        axes to draw on. When None, a new figure is created, the drawing
        occupying the left 9/10 of a 10x10 grid (the last column being kept
        for the optional color bar).
    cbar_axis : matplotlib Axes, optional
        axes for the color bar, only used if ``axis.color_bar`` is set.
        When None, it is created in the grid's last column (new figure) or
        carved next to the given ``ax``.
    **draw_specs_kw : updates of the default specification below

    Returns
    -------
    fig, ax : the figure and axes holding the drawing

    The default sheet_spec specification is:

    {
    "edge": {
      "visible": true,
      "width": 0.5,
      "head_width": 0.0,
      "length_includes_head": true,
      "shape": "right",
      "color": "#2b5d0a",
      "alpha": 0.8,
      "zorder": 1,
      "colormap": "viridis"
      },
    "vert": {
      "visible": false,
      "s": 100,
      "color": "#000a4b",
      "alpha": 0.3,
      "zorder": 2
      },
    "grad": {
      "color":"#000a4b",
      "alpha":0.5,
      "width":0.04
    },
    "face": {
      "visible": false,
      "color":"#8aa678",
      "alpha": 1.0,
      "zorder": -1
      },
    "axis": {
      "autoscale": true,
      "color_bar": false,
      "color_bar_cmap":"viridis",
      "color_bar_range":false,
      "color_bar_label":false,
      "color_bar_target":"face"
      }
    }

    Note
    ----
    Important note for quantitative colormap plots: make sure to normalize your
    values before getting the colors using
    draw_specs["face"]["color"] = cmap(pandas_holding_quantity_of_interest)
    For each plot normalize with respect to the current values
    (max and min) such that they lie between and including 0 to 1.
    Note that if you want to keep a constant colorbar range you have
    to choose the normalization to match the max and min of the color
    bar range you chose.
    """
    draw_specs = sheet_spec()
    spec_updater(draw_specs, draw_specs_kw)
    axis_spec = draw_specs.get("axis", {})
    with_cbar = axis_spec.get("color_bar")

    if ax is None:
        fig = plt.figure()
        grid0 = plt.GridSpec(10, 10)
        grid0.update(wspace=0.0)
        ax = fig.add_subplot(grid0[:, :9])
        if with_cbar and cbar_axis is None:
            cbar_axis = fig.add_subplot(grid0[:, 9])
    else:
        # Draw on the caller's axes rather than silently replacing them
        # with a new subplot.
        fig = ax.get_figure()
        if with_cbar and cbar_axis is None:
            cbar_axis, _ = mpl.colorbar.make_axes(ax)

    vert_spec = draw_specs["vert"]
    if vert_spec["visible"]:
        ax = draw_vert(sheet, coords, ax, **vert_spec)

    edge_spec = draw_specs["edge"]
    if edge_spec["visible"]:
        ax = draw_edge(sheet, coords, ax, **edge_spec)

    face_spec = draw_specs["face"]
    if face_spec["visible"]:
        ax = draw_face(sheet, coords, ax, **face_spec)

    if axis_spec.get("autoscale"):
        ax.autoscale()
    else:
        ax.set_xlim(axis_spec["x_min"], axis_spec["x_max"])
        ax.set_ylim(axis_spec["y_min"], axis_spec["y_max"])
    ax.set_aspect("equal")

    if not with_cbar:
        return fig, ax

    cmap = colormaps[axis_spec.get("color_bar_cmap")]
    color_range = axis_spec.get("color_bar_range")
    if not color_range:
        warnings.warn(
            "Since the quanity of interest should be normalized\n"
            "to pick face colours, color bar range should always be specified\n"
            "according to the normalization used. Default 0 to 1 range is used.\n"
        )
        norm = mpl.colors.Normalize(0.0, 1.0)
    else:
        norm = mpl.colors.Normalize(vmin=color_range[0], vmax=color_range[1])

    cb1 = mpl.colorbar.ColorbarBase(
        cbar_axis, cmap=cmap, norm=norm, orientation="vertical"
    )
    # a missing / falsy label falls back to arbitrary units
    cb1.set_label(axis_spec.get("color_bar_label") or "a.u.")
    return fig, ax
def draw_face(sheet, coords, ax, **draw_spec_kw):
    """Draws epithelial sheet polygonal faces in matplotlib

    Keyword values can be specified at the element
    level as columns of the sheet.face_df

    Parameters
    ----------
    sheet : the epithelium to draw
    coords : pair of vertex coordinate columns to project on
    ax : matplotlib Axes to which the faces are added
    **draw_spec_kw : updates of the default face specification
        (``sheet_spec()["face"]``)

    Returns
    -------
    ax : the same Axes, with a ``PolyCollection`` of the faces added

    If ``sheet.face_df`` has a "visible" column, only the sub-epithelium
    built from the edges of visible faces is drawn. Per-face color arrays
    are then re-indexed with the sub-epithelium's "face_o" column
    (presumably the original face positions — TODO confirm).
    """
    draw_spec = sheet_spec()["face"]
    draw_spec.update(**draw_spec_kw)
    collection_specs = parse_face_specs(draw_spec, sheet)

    if "visible" in sheet.face_df.columns:
        edges = sheet.edge_df[sheet.upcast_face(sheet.face_df["visible"])].index
        if edges.shape[0]:
            sheet = get_sub_eptm(sheet, edges)
            # .get: parse_face_specs returns {} when the face color is None
            color = collection_specs.get("facecolors")
            if isinstance(color, np.ndarray):
                faces = sheet.face_df["face_o"].values.astype(np.uint32)
                collection_specs["facecolors"] = color.take(faces, axis=0)
        else:
            warnings.warn("No face is visible")

    if not sheet.is_ordered:
        # face_polygons is computed on a re-ordered copy, leaving the
        # caller's sheet untouched
        sheet_ = sheet.copy()
        sheet_.reset_index(order=True)
        polys = sheet_.face_polygons(coords)
    else:
        polys = sheet.face_polygons(coords)

    p = PolyCollection(polys, closed=True, **collection_specs)
    ax.add_collection(p)
    return ax
def parse_face_specs(face_draw_specs, sheet):
    """Translate a face drawing specification into PolyCollection keywords.

    Parameters
    ----------
    face_draw_specs : dict
        face specification. Its "color" entry may be None, a color string,
        a sequence of per-face values or colors, or a callable taking
        ``sheet`` and returning one of those. A callable, or a sequence that
        is not an ndarray, is replaced in place by its evaluated value
        (sequences being converted to ndarray).
    sheet : the epithelium being drawn (used for callable / sequence colors)

    Returns
    -------
    collection_specs : dict
        with optional "facecolors" and "alpha" keys; empty when the color
        is None
    """
    collection_specs = {}
    color = face_draw_specs.get("color")
    if callable(color):
        color = color(sheet)
        face_draw_specs["color"] = color

    if color is None:
        return {}
    if isinstance(color, str):
        collection_specs["facecolors"] = color
    elif hasattr(color, "__len__"):
        if not isinstance(color, np.ndarray):
            # lists have no .shape / .min(), and pandas objects would be
            # returned as-is as colors: normalize to an ndarray
            color = np.asarray(color)
            face_draw_specs["color"] = color
        collection_specs["facecolors"] = _face_color_from_sequence(
            face_draw_specs, sheet
        )

    if "alpha" in face_draw_specs:
        collection_specs["alpha"] = face_draw_specs["alpha"]
    return collection_specs
def _face_color_from_sequence(face_spec, sheet):
    """Map a per-face color specification to an array of face colors.

    Parameters
    ----------
    face_spec : dict
        face specification. "color" holds either explicit colors of shape
        (Nf, 3) / (Nf, 4), returned as is, or scalar values of shape (Nf,),
        mapped through the colormap. Optional "colormap" (default "viridis")
        and "color_range" (vmin, vmax; defaults to the data min / max)
        control that mapping.
    sheet : object exposing ``Nf``, the number of faces

    Returns
    -------
    colors : (Nf, 3) or (Nf, 4) array. When the normalization range is
        degenerate (span < 1e-10), a uniform gray (Nf, 3) array.

    Raises
    ------
    ValueError
        if the color shape is none of (Nf, 3), (Nf, 4) or (Nf,)
    """
    color_ = np.asarray(face_spec["color"])
    if color_.shape in [(sheet.Nf, 3), (sheet.Nf, 4)]:
        return color_
    if color_.shape != (sheet.Nf,):
        raise ValueError(
            "shape of `face_spec['color']` must be either (Nf, 3), (Nf, 4) or (Nf,)"
        )

    color_range = face_spec.get("color_range")
    if color_range is None:
        color_min, color_max = color_.min(), color_.max()
    else:
        # an explicit range is honoured even for uniform data
        color_min, color_max = color_range

    if abs(color_max - color_min) < 1e-10:
        # normalizing would divide by (almost) zero
        log.info("Attempting to draw a colormap with a uniform value")
        return np.ones((sheet.Nf, 3)) * 0.5

    cmap = colormaps[face_spec.get("colormap", "viridis")]
    normed = (color_ - color_min) / (color_max - color_min)
    return cmap(normed)
def draw_vert(sheet, coords, ax, **draw_spec_kw):
    """Draw junction vertices in matplotlib.

    Parameters
    ----------
    sheet : the epithelium to draw
    coords : pair of vertex coordinate columns to plot
    ax : matplotlib Axes to draw on
    **draw_spec_kw : updates of the default vertex specification
        (``sheet_spec()["vert"]``); the merged specification is passed to
        ``ax.scatter``

    Returns
    -------
    ax : the same Axes

    When ``sheet.vert_df`` has a "z_coord" column, vertices are sorted by it
    so that those with larger values are drawn last.
    """
    draw_spec = sheet_spec()["vert"]
    draw_spec.update(**draw_spec_kw)
    if "visible" not in draw_spec_kw:
        # The default "visible" flag tells sheet_view whether to call this
        # function; it must not hide the markers of an explicit call.
        draw_spec.pop("visible", None)

    x, y = coords
    verts = sheet.vert_df
    if "z_coord" in verts.columns:
        pos = verts.sort_values("z_coord")[coords]
    else:
        pos = verts[coords]
    # use the merged spec (defaults + overrides), not only the overrides
    ax.scatter(pos[x], pos[y], **draw_spec)
    return ax
def draw_edge(sheet, coords, ax, **draw_spec_kw):
    """Draw the sheet edges in matplotlib, as line segments or arrows.

    Parameters
    ----------
    sheet : the epithelium to draw. Its edge_df must hold the "s", "t" and
        "d" prefixed coordinate columns (e.g. "sx", "tx", "dx").
    coords : pair of coordinate names, e.g. ["x", "y"]
    ax : matplotlib Axes to draw on
    **draw_spec_kw : updates of the default edge specification
        (``sheet_spec()["edge"]``)

    Returns
    -------
    ax : the same Axes

    With a truthy "head_width", each edge is drawn as a ``FancyArrow``
    starting at its source along its "d" vector, in a ``PatchCollection``.
    Edges with a vanishing apparent length (<= 1e-6) are skipped. Otherwise
    the edges are drawn as a ``LineCollection`` of source -> target segments.
    """
    draw_spec = sheet_spec()["edge"]
    draw_spec.update(**draw_spec_kw)
    arrow_specs, collections_specs = _parse_edge_specs(draw_spec, sheet)
    dx, dy = ("d" + c for c in coords)
    sx, sy = ("s" + c for c in coords)
    tx, ty = ("t" + c for c in coords)

    if draw_spec.get("head_width"):
        app_length = (
            np.hypot(sheet.edge_df[dx], sheet.edge_df[dy]) * sheet.edge_df.length.mean()
        )
        keep = (app_length > 1e-6).to_numpy()
        patches = [
            FancyArrow(*edge[[sx, sy, dx, dy]], **arrow_specs)
            for _, edge in sheet.edge_df[keep].iterrows()
        ]
        # Per-edge colors / widths must be filtered like the patches,
        # otherwise they get misaligned with the remaining arrows.
        for key in ("edgecolors", "linewidths"):
            value = collections_specs.get(key)
            if (
                not isinstance(value, str)
                and np.ndim(value) >= 1
                and len(value) == keep.size
            ):
                collections_specs[key] = np.asarray(value)[keep]
        ax.add_collection(
            PatchCollection(patches, match_original=False, **collections_specs)
        )
    else:
        segments = sheet.edge_df[[sx, sy, tx, ty]].to_numpy().reshape((-1, 2, 2))
        ax.add_collection(LineCollection(segments, **collections_specs))
    return ax
def _parse_edge_specs(edge_draw_specs, sheet):
    """Split an edge specification into FancyArrow and collection keywords.

    Parameters
    ----------
    edge_draw_specs : dict
        edge specification. A callable "color" entry is evaluated on
        ``sheet`` and replaced in place by its result.
    sheet : the epithelium being drawn

    Returns
    -------
    arrow_specs : dict
        the "head_width", "length_includes_head" and "shape" entries found
        in ``edge_draw_specs``
    collection_specs : dict
        optional color ("edgecolors" when head_width is truthy, i.e. arrows
        are drawn, "colors" otherwise), "linewidths" (from "width") and
        "alpha" entries
    """
    arrow_names = ("head_width", "length_includes_head", "shape")
    arrow_specs = {
        name: edge_draw_specs[name] for name in edge_draw_specs if name in arrow_names
    }
    target_key = "edgecolors" if arrow_specs.get("head_width") else "colors"

    collection_specs = {}
    if "color" in edge_draw_specs:
        if callable(edge_draw_specs["color"]):
            edge_draw_specs["color"] = edge_draw_specs["color"](sheet)
        spec_color = edge_draw_specs["color"]
        if isinstance(spec_color, str):
            collection_specs[target_key] = spec_color
        elif hasattr(spec_color, "__len__"):
            collection_specs[target_key] = _wire_color_from_sequence(
                edge_draw_specs, sheet
            )

    for source_key, dest_key in (("width", "linewidths"), ("alpha", "alpha")):
        if source_key in edge_draw_specs:
            collection_specs[dest_key] = edge_draw_specs[source_key]
    return arrow_specs, collection_specs
def _wire_color_from_sequence(edge_spec, sheet):
    """Map a per-vertex or per-edge color specification to edge colors.

    Parameters
    ----------
    edge_spec : dict
        edge specification whose "color" entry is one of:
        - (Nv, 3) / (Nv, 4) vertex colors, averaged over each edge's
          source and target;
        - (Nv,) vertex values, averaged per edge then mapped through the
          colormap;
        - (Ne, 3) / (Ne, 4) edge colors, returned as is;
        - (Ne,) edge values, mapped through the colormap.
        Optional "colormap" (default "viridis") and "color_range"
        (vmin, vmax; defaults to the data min / max) control the mapping.
    sheet : the epithelium, providing Nv, Ne, vert_df and
        upcast_srce / upcast_trgt

    Returns
    -------
    colors : per-edge colors. A uniform gray (Ne, 3) array when values are
        to be mapped over a degenerate range (span < 1e-10). None, with a
        warning, when the shape matches none of the cases above.
    """
    color_ = edge_spec["color"]
    if not hasattr(color_, "shape"):
        # plain sequences (lists, tuples) have no .shape / .min()
        color_ = np.asarray(color_)

    color_range = edge_spec.get("color_range")
    if color_range is None:
        color_min, color_max = color_.min(), color_.max()
    else:
        # an explicit range is honoured even for uniform data
        color_min, color_max = color_range
    flat_range = abs(color_max - color_min) < 1e-10

    if color_.shape in [(sheet.Nv, 3), (sheet.Nv, 4)]:
        return (sheet.upcast_srce(color_) + sheet.upcast_trgt(color_)) / 2

    if color_.shape == (sheet.Nv,):
        if flat_range:
            warnings.warn("Attempting to draw a colormap with a uniform value")
            return np.ones((sheet.Ne, 3)) * 0.7
        if not hasattr(color_, "index"):
            color_ = pd.Series(color_, index=sheet.vert_df.index)
        color_ = (sheet.upcast_srce(color_) + sheet.upcast_trgt(color_)) / 2
        cmap = colormaps[edge_spec.get("colormap", "viridis")]
        return cmap((color_ - color_min) / (color_max - color_min))

    if color_.shape in [(sheet.Ne, 3), (sheet.Ne, 4)]:
        return color_

    if color_.shape == (sheet.Ne,):
        if flat_range:
            warnings.warn("Attempting to draw a colormap with a uniform value")
            # one color per *edge* (the previous code returned Nv rows)
            return np.ones((sheet.Ne, 3)) * 0.7
        cmap = colormaps[edge_spec.get("colormap", "viridis")]
        return cmap((color_ - color_min) / (color_max - color_min))

    warnings.warn(
        "Unexpected shape for the edge color specification, "
        "falling back to default colors"
    )
    return None
def quick_edge_draw(sheet, coords=["x", "y"], ax=None, **draw_spec_kw):
    """Fast edge drawing: every edge in a single ``ax.plot`` line.

    Parameters
    ----------
    sheet : the epithelium to draw
    coords : pair of coordinate names to project on
    ax : matplotlib Axes, optional; a new figure and axes are created if None
    **draw_spec_kw : passed as is to ``ax.plot``

    Returns
    -------
    fig, ax : the figure and axes holding the drawing
    """
    fig, ax = plt.subplots() if ax is None else (ax.get_figure(), ax)
    xs, ys = _get_lines(sheet, coords)
    ax.plot(xs, ys, **draw_spec_kw)
    ax.set_aspect("equal")
    return fig, ax
def _get_lines(sheet, coords):
    """Return NaN-separated x and y arrays tracing every edge.

    For each edge the arrays hold ``[nan, source, target]``, so a single
    ``ax.plot`` call draws all edges as disconnected segments. The trick
    comes from
    https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/tri/triplot.py#L65
    Source / target positions are read from the "s" / "t" prefixed edge
    columns when all are present, otherwise upcast from the vertices.

    Returns
    -------
    lines_x, lines_y : float arrays of length 3 * sheet.Ne
    """
    src_cols = [f"s{c}" for c in coords]
    tgt_cols = [f"t{c}" for c in coords]
    if set(src_cols + tgt_cols) <= set(sheet.edge_df.columns):
        src_xs, src_ys = sheet.edge_df[src_cols].values.T
        tgt_xs, tgt_ys = sheet.edge_df[tgt_cols].values.T
    else:
        src_xs, src_ys = sheet.upcast_srce(sheet.vert_df[coords]).values.T
        tgt_xs, tgt_ys = sheet.upcast_trgt(sheet.vert_df[coords]).values.T

    def interleave(starts, ends):
        # one (nan, start, end) row per edge, flattened row by row
        rows = np.full((sheet.Ne, 3), np.nan)
        rows[:, 1] = starts
        rows[:, 2] = ends
        return rows.ravel()

    return interleave(src_xs, tgt_xs), interleave(src_ys, tgt_ys)
def plot_forces(
    sheet, geom, model, coords, scaling, ax=None, approx_grad=None, **draw_specs_kw
):
    """Plot the net force at each vertex as an arrow.

    The force is the opposite of the energy gradient, -grad E, with its
    amplitude multiplied by ``scaling``.

    Parameters
    ----------
    sheet : the epithelium; its vert_df must have an "is_active" column
    geom : geometry class, passed to ``approx_grad``
    model : the energy model, providing ``compute_gradient``
    coords : list of coordinate names, e.g. ["x", "y"]
    scaling : factor applied to the force amplitudes
    ax : matplotlib Axes, optional; the edges are drawn with
        :func:`quick_edge_draw` on a new figure when None
    approx_grad : callable, optional
        ``approx_grad(sheet, geom, model)`` returning a flat gradient array
        over the active vertices, used instead of ``model.compute_gradient``
    **draw_specs_kw : updates of the drawing specification. Its "grad" entry
        is passed to ``ax.arrow``; an "extract" entry, if present, is passed
        to ``sheet.extract_bounding_box``.

    Returns
    -------
    fig, ax : the figure and axes holding the drawing

    Side effect: the "g" prefixed coords columns of ``sheet.vert_df`` are
    overwritten with the scaled forces of the active vertices.
    """
    draw_specs = sheet_spec()
    spec_updater(draw_specs, draw_specs_kw)
    gcoords = ["g" + c for c in coords]

    if approx_grad is None:
        grad_i = model.compute_gradient(sheet, components=False) * scaling
        grad_i = grad_i.loc[sheet.vert_df["is_active"].astype(bool)]
    else:
        flat_grad = approx_grad(sheet, geom, model)
        active_index = sheet.vert_df[sheet.vert_df.is_active.astype(bool)].index
        grad_i = (
            pd.DataFrame(
                flat_grad.reshape((-1, len(sheet.coords))),
                index=active_index,
                columns=["g" + c for c in sheet.coords],
            )
            * scaling
        )

    sheet.vert_df[gcoords] = -grad_i[gcoords]  # F = -grad E

    if "extract" in draw_specs:
        sheet = sheet.extract_bounding_box(**draw_specs["extract"])

    if ax is None:
        fig, ax = quick_edge_draw(sheet, coords)
    else:
        fig = ax.get_figure()

    arrow_style = draw_specs["grad"]
    # each row is (x, y, gx, gy) for 2D coords: origin then force vector
    for row in sheet.vert_df[coords + gcoords].itertuples(index=False):
        ax.arrow(*row, **arrow_style)
    return fig, ax
def plot_scaled_energies(sheet, geom, model, scales, ax=None):
    """Plot the mean energy terms of ``model`` against a scaling factor.

    For each scale, the mean of every energy component returned by
    ``model.compute_energy(sheet, True)`` is evaluated through
    ``utils.scaled_unscaled``; the curves and their sum ("total") are
    plotted.

    Parameters
    ----------
    sheet : a :class:`Sheet` object
    geom : a :class:`Geometry` class
    model : a :class:`Model`; its ``labels`` name the energy components
    scales : sequence of float, e.g. the output of ``np.linspace``
    ax : matplotlib Axes, optional; created if None

    Returns
    -------
    fig : a :class:`matplotlib.figure.Figure` instance
    ax : a :class:`matplotlib.axes.Axes` instance
    """
    from ..utils import scaled_unscaled

    def mean_energies():
        return np.array([term.mean() for term in model.compute_energy(sheet, True)])

    energies = np.array(
        [scaled_unscaled(mean_energies, scale, sheet, geom) for scale in scales]
    )

    fig, ax = plt.subplots() if ax is None else (ax.get_figure(), ax)
    ax.plot(scales, energies.sum(axis=1), "k-", lw=4, alpha=0.3, label="total")
    for label, curve in zip(model.labels, energies.T):
        ax.plot(scales, curve, label=label)
    ax.legend()
    return fig, ax
def get_arc_data(sheet):
    """Compute the circle arcs used to draw curved edges.

    Reads the "curvature", "dx", "dy", "length", "sagitta",
    "arc_chord_angle" and "chord_orient" columns of ``sheet.edge_df`` and
    the "x" / "y" vertex positions (upcast to edge sources and targets).

    Returns
    -------
    center_data : pd.DataFrame indexed like the edges, with columns
        "radius" (absolute value of 1 / curvature), "x" and "y" (the
        computed arc center) and "theta1", "theta2" (arc angles in radians;
        :func:`curved_view` converts them to degrees)
    """
    edges = sheet.edge_df
    srce_pos = sheet.upcast_srce(sheet.vert_df[sheet.coords])
    trgt_pos = sheet.upcast_trgt(sheet.vert_df[sheet.coords])

    radius = 1 / edges["curvature"]
    unit_x = edges["dx"] / edges["length"]
    unit_y = edges["dy"] / edges["length"]
    offset = radius - edges["sagitta"]
    # NOTE(review): the offset direction (-unit_y, -unit_x) is perpendicular
    # to (unit_x, unit_y) only when one component is zero; a perpendicular
    # would be (-unit_y, +unit_x). Left unchanged — confirm against tests.
    center_x = (srce_pos.x + trgt_pos.x) / 2 - unit_y * offset
    center_y = (srce_pos.y + trgt_pos.y) / 2 - unit_x * offset

    # Original author's note: the angle formulas below were obtained with a
    # fair amount of trial and error.
    chord_angle = edges["arc_chord_angle"]
    orientation = edges["chord_orient"]
    sign = np.sign(chord_angle)
    rot = orientation - sign * np.pi / 2

    return pd.DataFrame.from_dict(
        {
            "radius": np.abs(radius),
            "x": center_x,
            "y": center_y,
            "theta1": (-chord_angle + rot) * sign,
            "theta2": (chord_angle + rot) * sign,
        }
    )
def curved_view(sheet, radius_cutoff=1e3):
    """Draw the sheet with its edges rendered as circle arcs.

    The base view comes from :func:`sheet_view` with straight edges hidden.
    Each edge is then drawn as an ``Arc`` described by :func:`get_arc_data`,
    or as a straight path between its source and target vertices when its
    radius exceeds ``radius_cutoff``.

    Returns
    -------
    fig, ax : the figure and axes of the drawing
    """
    arcs = get_arc_data(sheet)
    fig, ax = sheet_view(sheet, edge={"visible": False})

    def to_patch(edge_id, arc):
        if arc["radius"] > radius_cutoff:
            # nearly flat edge: straight source -> target path
            ends = sheet.edge_df.loc[edge_id, ["srce", "trgt"]]
            return PathPatch(Path(sheet.vert_df.loc[ends, sheet.coords]))
        diameter = 2 * arc["radius"]
        return Arc(
            arc[["x", "y"]],
            diameter,
            diameter,
            theta1=arc["theta1"] * 180 / np.pi,
            theta2=arc["theta2"] * 180 / np.pi,
        )

    patches = [to_patch(edge_id, arc) for edge_id, arc in arcs.iterrows()]
    ax.add_collection(PatchCollection(patches, False, facecolors="none"))
    ax.autoscale()
    return fig, ax
def plot_junction(eptm, edge_index, coords=None):
    """Plots local graph around a junction, for debugging purposes.

    The junction's two vertices are marked with "+", and the targets of the
    edges leaving each of them are scattered. Edges leaving the source
    vertex are drawn as thick translucent red lines, and edges leaving the
    target vertex as dashed black lines. Edges leaving those neighbors, except
    the ones pointing back to the junction, are drawn as thin black lines.

    Parameters
    ----------
    eptm : the epithelium; its edge_df must hold the "srce", "trgt" and
        "s" / "t" prefixed coordinate columns
    edge_index : index of the junction edge in ``eptm.edge_df``
    coords : pair of coordinate names, default ["x", "y"]

    Returns
    -------
    fig, ax : the (12 x 12 inches) figure and its axes
    """
    if coords is None:
        coords = ["x", "y"]
    x, y = coords
    edge_df = eptm.edge_df

    v10, v11 = edge_df.loc[edge_index, ["srce", "trgt"]]
    fig, ax = plt.subplots()
    ax.scatter(*eptm.vert_df.loc[[v10, v11], coords].values.T, marker="+", s=300)

    v10_out = set(edge_df.loc[edge_df["srce"] == v10, "trgt"]) - {v11}
    v11_out = set(edge_df.loc[edge_df["srce"] == v11, "trgt"]) - {v10}
    verts = v10_out | v11_out
    # pandas >= 2 refuses sets as .loc indexers: pass (sorted) lists
    ax.scatter(*eptm.vert_df.loc[sorted(v10_out), coords].values.T)
    ax.scatter(*eptm.vert_df.loc[sorted(v11_out), coords].values.T)

    def segment(edge):
        # x and y pairs from the edge's source to its target
        return edge[["s" + x, "t" + x]], edge[["s" + y, "t" + y]]

    for _, edge in edge_df[edge_df["srce"] == v10].iterrows():
        ax.plot(*segment(edge), lw=3, alpha=0.3, c="r")
    for _, edge in edge_df[edge_df["srce"] == v11].iterrows():
        ax.plot(*segment(edge), "k--")
    for v in sorted(verts):
        for _, edge in edge_df[edge_df["srce"] == v].iterrows():
            if edge["trgt"] in {v10, v11}:
                continue
            ax.plot(*segment(edge), "k", lw=0.4)

    fig.set_size_inches(12, 12)
    return fig, ax