from os.path import join as pjoin
import numpy as np
from fury.colormap import boys2rgb, colormap_lookup_table, orient2rgb
from fury.decorators import warn_on_args_to_kwargs
from fury.lib import (
VTK_OBJECT,
Actor,
CellArray,
Command,
PolyData,
PolyDataMapper,
calldata_type,
numpy_support,
)
from fury.shaders import (
attribute_to_actor,
compose_shader,
import_fury_shader,
shader_to_actor,
)
from fury.utils import apply_affine, numpy_to_vtk_colors, numpy_to_vtk_points
[docs]
class PeakActor(Actor):
"""FURY actor for visualizing DWI peaks.
Parameters
----------
directions : ndarray
Peak directions. The shape of the array should be (X, Y, Z, D, 3).
indices : tuple
Indices given in tuple(x_indices, y_indices, z_indices)
format for mapping 2D ODF array to 3D voxel grid.
values : ndarray, optional
Peak values. The shape of the array should be (X, Y, Z, D).
affine : array, optional
4x4 transformation array from native coordinates to world coordinates.
colors : None or string ('rgb_standard') or tuple (3D or 4D) or array/ndarray (N, 3 or 4) or array/ndarray (K, 3 or 4) or array/ndarray(N, ) or array/ndarray (K, )
If None a standard orientation colormap is used for every line.
If one tuple of color is used. Then all streamlines will have the same
color.
If an array (N, 3 or 4) is given, where N is equal to the number of
points. Then every point is colored with a different RGB(A) color.
If an array (K, 3 or 4) is given, where K is equal to the number of
lines. Then every line is colored with a different RGB(A) color.
If an array (N, ) is given, where N is the number of points then these
are considered as the values to be used by the colormap.
If an array (K,) is given, where K is the number of lines then these
are considered as the values to be used by the colormap.
lookup_colormap : vtkLookupTable, optional
Add a default lookup table to the colormap. Default is None which calls
:func:`fury.actor.colormap_lookup_table`.
linewidth : float, optional
Line thickness. Default is 1.
symmetric: bool, optional
If True, peaks are drawn for both peaks_dirs and -peaks_dirs. Else,
peaks are only drawn for directions given by peaks_dirs. Default is
True.
""" # noqa: E501
@warn_on_args_to_kwargs()
def __init__(
self,
directions,
indices,
*,
values=None,
affine=None,
colors=None,
lookup_colormap=None,
linewidth=1,
symmetric=True,
):
if affine is not None:
w_pos = apply_affine(affine, np.asarray(indices).T)
valid_dirs = directions[indices]
num_dirs = len(np.nonzero(np.abs(valid_dirs).max(axis=-1) > 0)[0])
pnts_per_line = 2
points_array = np.empty((num_dirs * pnts_per_line, 3))
centers_array = np.empty_like(points_array, dtype=int)
diffs_array = np.empty_like(points_array)
line_count = 0
for idx, center in enumerate(zip(indices[0], indices[1], indices[2])):
if affine is None:
xyz = np.asarray(center)
else:
xyz = w_pos[idx, :]
valid_peaks = np.nonzero(np.abs(valid_dirs[idx, :, :]).max(axis=-1) > 0.0)[
0
]
for direction in valid_peaks:
if values is not None:
pv = values[center][direction]
else:
pv = 1.0
if symmetric:
point_i = directions[center][direction] * pv + xyz
point_e = -directions[center][direction] * pv + xyz
else:
point_i = directions[center][direction] * pv + xyz
point_e = xyz
diff = point_e - point_i
points_array[line_count * pnts_per_line, :] = point_e
points_array[line_count * pnts_per_line + 1, :] = point_i
centers_array[line_count * pnts_per_line, :] = center
centers_array[line_count * pnts_per_line + 1, :] = center
diffs_array[line_count * pnts_per_line, :] = diff
diffs_array[line_count * pnts_per_line + 1, :] = diff
line_count += 1
vtk_points = numpy_to_vtk_points(points_array)
vtk_cells = _points_to_vtk_cells(points_array)
colors_tuple = _peaks_colors_from_points(points_array, colors=colors)
vtk_colors, colors_are_scalars, self.__global_opacity = colors_tuple
poly_data = PolyData()
poly_data.SetPoints(vtk_points)
poly_data.SetLines(vtk_cells)
poly_data.GetPointData().SetScalars(vtk_colors)
self.__mapper = PolyDataMapper()
self.__mapper.SetInputData(poly_data)
self.__mapper.ScalarVisibilityOn()
self.__mapper.SetScalarModeToUsePointFieldData()
self.__mapper.SelectColorArray("colors")
self.__mapper.Update()
self.SetMapper(self.__mapper)
attribute_to_actor(self, centers_array, "center")
attribute_to_actor(self, diffs_array, "diff")
vs_var_dec = """
in vec3 center;
in vec3 diff;
flat out vec3 centerVertexMCVSOutput;
"""
fs_var_dec = """
flat in vec3 centerVertexMCVSOutput;
uniform bool isRange;
uniform vec3 crossSection;
uniform vec3 lowRanges;
uniform vec3 highRanges;
"""
orient_to_rgb = import_fury_shader(pjoin("utils", "orient_to_rgb.glsl"))
visible_cross_section = import_fury_shader(
pjoin("interaction", "visible_cross_section.glsl")
)
visible_range = import_fury_shader(pjoin("interaction", "visible_range.glsl"))
vs_dec = compose_shader([vs_var_dec, orient_to_rgb])
fs_dec = compose_shader([fs_var_dec, visible_cross_section, visible_range])
vs_impl = """
centerVertexMCVSOutput = center;
if (vertexColorVSOutput.rgb == vec3(0))
{
vertexColorVSOutput.rgb = orient2rgb(diff);
}
"""
fs_impl = """
if (isRange)
{
if (!inVisibleRange(centerVertexMCVSOutput))
discard;
}
else
{
if (!inVisibleCrossSection(centerVertexMCVSOutput))
discard;
}
"""
shader_to_actor(self, "vertex", decl_code=vs_dec, impl_code=vs_impl)
shader_to_actor(self, "fragment", decl_code=fs_dec)
shader_to_actor(self, "fragment", impl_code=fs_impl, block="light")
# Color scale with a lookup table
if colors_are_scalars:
if lookup_colormap is None:
lookup_colormap = colormap_lookup_table()
self.__mapper.SetLookupTable(lookup_colormap)
self.__mapper.UseLookupTableScalarRangeOn()
self.__mapper.Update()
self.__lw = linewidth
self.GetProperty().SetLineWidth(self.__lw)
if self.__global_opacity >= 0:
self.GetProperty().SetOpacity(self.__global_opacity)
self.__min_centers = np.min(indices, axis=1)
self.__max_centers = np.max(indices, axis=1)
self.__is_range = True
self.__low_ranges = self.__min_centers
self.__high_ranges = self.__max_centers
self.__cross_section = self.__high_ranges // 2
self.__mapper.AddObserver(
Command.UpdateShaderEvent, self.__display_peaks_vtk_callback(None, None)
)
@warn_on_args_to_kwargs()
@calldata_type(VTK_OBJECT)
def __display_peaks_vtk_callback(self, caller, event, *, calldata=None):
if calldata is not None:
calldata.SetUniformi("isRange", self.__is_range)
calldata.SetUniform3f("highRanges", self.__high_ranges)
calldata.SetUniform3f("lowRanges", self.__low_ranges)
calldata.SetUniform3f("crossSection", self.__cross_section)
[docs]
def display_cross_section(self, x, y, z):
if self.__is_range:
self.__is_range = False
self.__cross_section = [x, y, z]
[docs]
def display_extent(self, x1, x2, y1, y2, z1, z2):
if not self.__is_range:
self.__is_range = True
self.__low_ranges = [x1, y1, z1]
self.__high_ranges = [x2, y2, z2]
@property
def cross_section(self):
return self.__cross_section
@property
def global_opacity(self):
return self.__global_opacity
@global_opacity.setter
def global_opacity(self, opacity):
self.__global_opacity = opacity
self.GetProperty().SetOpacity(self.__global_opacity)
@property
def high_ranges(self):
return self.__high_ranges
@property
def is_range(self):
return self.__is_range
@property
def low_ranges(self):
return self.__low_ranges
@property
def linewidth(self):
return self.__lw
@linewidth.setter
def linewidth(self, linewidth):
self.__lw = linewidth
self.GetProperty().SetLineWidth(self.__lw)
@property
def max_centers(self):
return self.__max_centers
@property
def min_centers(self):
return self.__min_centers
@warn_on_args_to_kwargs()
def _orientation_colors(points, *, cmap="rgb_standard"):
"""
Parameters
----------
points : (N, 3) array or ndarray
points coordinates array.
cmap : string ('rgb_standard', 'boys_standard'), optional
colormap.
Returns
-------
colors_list : ndarray
list of Kx3 colors. Where K is the number of lines.
"""
if cmap.lower() == "rgb_standard":
col_list = [
orient2rgb(points[i + 1] - points[i]) for i in range(0, len(points), 2)
]
elif cmap.lower() == "boys_standard":
col_list = [
boys2rgb(points[i + 1] - points[i]) for i in range(0, len(points), 2)
]
else:
raise ValueError(
"Invalid colormap. The only available options are "
"'rgb_standard' and 'boys_standard'."
)
return np.asarray(col_list)
@warn_on_args_to_kwargs()
def _peaks_colors_from_points(points, *, colors=None, points_per_line=2):
"""Return a VTK scalar array containing colors information for each one of
the peaks according to the policy defined by the parameter colors.
Parameters
----------
points : (N, 3) array or ndarray
points coordinates array.
colors : None or string ('rgb_standard') or tuple (3D or 4D) or
array/ndarray (N, 3 or 4) or array/ndarray (K, 3 or 4) or
array/ndarray(N, ) or array/ndarray (K, )
If None a standard orientation colormap is used for every line.
If one tuple of color is used. Then all streamlines will have the same
color.
If an array (N, 3 or 4) is given, where N is equal to the number of
points. Then every point is colored with a different RGB(A) color.
If an array (K, 3 or 4) is given, where K is equal to the number of
lines. Then every line is colored with a different RGB(A) color.
If an array (N, ) is given, where N is the number of points then these
are considered as the values to be used by the colormap.
If an array (K,) is given, where K is the number of lines then these
are considered as the values to be used by the colormap.
points_per_line : int (1 or 2), optional
number of points per peak direction.
Returns
-------
color_array : vtkDataArray
vtk scalar array with name 'colors'.
colors_are_scalars : bool
indicates whether or not the colors are scalars to be interpreted by a
colormap.
global_opacity : float
returns 1 if the colors array doesn't contain opacity otherwise -1.
"""
num_pnts = len(points)
num_lines = num_pnts // points_per_line
colors_are_scalars = False
global_opacity = 1
if colors is None or colors == "rgb_standard":
# Automatic RGB colors
colors = np.asarray((0, 0, 0))
color_array = numpy_to_vtk_colors(np.tile(255 * colors, (num_pnts, 1)))
elif type(colors) is tuple:
global_opacity = 1 if len(colors) == 3 else -1
colors = np.asarray(colors)
color_array = numpy_to_vtk_colors(np.tile(255 * colors, (num_pnts, 1)))
else:
colors = np.asarray(colors)
if len(colors) == num_lines:
pnts_colors = np.repeat(colors, points_per_line, axis=0)
if colors.ndim == 1: # Scalar per line
color_array = numpy_support.numpy_to_vtk(pnts_colors, deep=True)
colors_are_scalars = True
elif colors.ndim == 2: # RGB(A) color per line
global_opacity = 1 if colors.shape[1] == 3 else -1
color_array = numpy_to_vtk_colors(255 * pnts_colors)
elif len(colors) == num_pnts:
if colors.ndim == 1: # Scalar per point
color_array = numpy_support.numpy_to_vtk(colors, deep=True)
colors_are_scalars = True
elif colors.ndim == 2: # RGB(A) color per point
global_opacity = 1 if colors.shape[1] == 3 else -1
color_array = numpy_to_vtk_colors(255 * colors)
color_array.SetName("colors")
return color_array, colors_are_scalars, global_opacity
@warn_on_args_to_kwargs()
def _points_to_vtk_cells(points, *, points_per_line=2):
"""Return the VTK cell array for the peaks given the set of points
coordinates.
Parameters
----------
points : (N, 3) array or ndarray
points coordinates array.
points_per_line : int (1 or 2), optional
number of points per peak direction.
Returns
-------
cell_array : vtkCellArray
connectivity + offset information.
"""
num_pnts = len(points)
num_cells = num_pnts // points_per_line
cell_array = CellArray()
"""
Connectivity is an array that contains the indices of the points that
need to be connected in the visualization. The indices start from 0.
"""
connectivity = np.asarray(list(range(0, num_pnts)), dtype=int)
"""
Offset is an array that contains the indices of the first point of
each line. The indices start from 0 and given the known geometry of
this actor the creation of this array requires a 2 points padding
between indices.
"""
offset = np.asarray(list(range(0, num_pnts + 1, points_per_line)), dtype=int)
vtk_array_type = numpy_support.get_vtk_array_type(connectivity.dtype)
cell_array.SetData(
numpy_support.numpy_to_vtk(offset, deep=True, array_type=vtk_array_type),
numpy_support.numpy_to_vtk(connectivity, deep=True, array_type=vtk_array_type),
)
cell_array.SetNumberOfCells(num_cells)
return cell_array