# -*- coding: utf-8 -*-
import numpy as np
from fury.colormap import create_colormap
from fury.decorators import warn_on_args_to_kwargs
from fury.lib import Actor, PolyData, PolyDataMapper
from fury.utils import (
    apply_affine,
    set_polydata_colors,
    set_polydata_triangles,
    set_polydata_vertices,
)
[docs]
class OdfSlicerActor(Actor):
    """VTK actor for visualizing slices of ODF field.
    Parameters
    ----------
    odfs : ndarray
        SF or SH coefficients 2-dimensional array.
    vertices: ndarray
        The sphere vertices used for SH to SF projection.
    faces: ndarray
        Indices of sphere vertices forming triangles. Should be
        ordered clockwise (see fury.utils.fix_winding_order).
    indices: tuple
        Indices given in tuple(x_indices, y_indices, z_indices)
        format for mapping 2D ODF array to 3D voxel grid.
    scale : float
        Multiplicative factor to apply to ODF amplitudes.
    norm : bool
        Normalize SF amplitudes so that the maximum
        ODF amplitude per voxel along a direction is 1.
    radial_scale : bool
        Scale sphere points by ODF values.
    global_cm : bool
        If True the colormap will be applied in all ODFs. If False
        it will be applied individually at each voxel.
    colormap : None or str
        The name of the colormap to use. Matplotlib colormaps are supported
        (e.g., 'inferno'). If None then a RGB colormap is used.
    opacity : float
        Takes values from 0 (fully transparent) to 1 (opaque).
    affine : array
        optional 4x4 transformation array from native
        coordinates to world coordinates.
    B : ndarray (n_coeffs, n_vertices)
        Optional SH to SF matrix for projecting `odfs` given in SH
        coefficients on the `sphere`. If None, then the input is assumed
        to be expressed in SF coefficients.
    """
    @warn_on_args_to_kwargs()
    def __init__(
        self,
        odfs,
        vertices,
        faces,
        indices,
        scale,
        norm,
        radial_scale,
        shape,
        global_cm,
        colormap,
        opacity,
        *,
        affine=None,
        B=None,
    ):
        self.vertices = vertices
        self.faces = faces
        self.odfs = odfs
        self.indices = indices
        self.B = B
        self.radial_scale = radial_scale
        self.colormap = colormap
        self.grid_shape = shape
        self.global_cm = global_cm
        # declare a mask to be instantiated in slice_along_axis
        self.mask = None
        # If a B matrix is given, odfs are expected to
        # be in SH basis coefficients.
        if self.B is not None:
            # In that case, we need to save our normalisation and scale
            # to apply them after conversion from SH to SF.
            self.norm = norm
            self.scale = scale
        else:
            # If our input is in SF coefficients, we can normalise and
            # scale it only once, here.
            if norm:
                self.odfs /= np.abs(self.odfs).max(axis=-1, keepdims=True)
            self.odfs *= scale
        # Compute world coordinates of an affine is supplied
        self.affine = affine
        if self.affine is not None:
            self.w_verts = self.vertices.dot(affine[:3, :3])
            self.w_pos = apply_affine(affine, np.asarray(self.indices).T)
        # Initialize mapper and slice to the
        # middle of the volume along Z axis
        self.mapper = PolyDataMapper()
        self.SetMapper(self.mapper)
        self.slice_along_axis(self.grid_shape[-1] // 2)
        self.set_opacity(opacity)
[docs]
    def set_opacity(self, opacity):
        """Set opacity value of ODFs to display."""
        self.GetProperty().SetOpacity(opacity) 
[docs]
    def display_extent(self, x1, x2, y1, y2, z1, z2):
        """Set visible volume from x1 (inclusive) to x2 (inclusive),
        y1 (inclusive) to y2 (inclusive), z1 (inclusive) to z2
        (inclusive).
        """
        mask = np.zeros(self.grid_shape, dtype=bool)
        mask[x1 : x2 + 1, y1 : y2 + 1, z1 : z2 + 1] = True
        self.mask = mask
        self._update_mapper() 
[docs]
    @warn_on_args_to_kwargs()
    def slice_along_axis(self, slice_index, *, axis="zaxis"):
        """Slice ODF field at given `slice_index` along axis
        in ['xaxis', 'yaxis', zaxis'].
        """
        if axis == "xaxis":
            self.display_extent(
                slice_index,
                slice_index,
                0,
                self.grid_shape[1] - 1,
                0,
                self.grid_shape[2] - 1,
            )
        elif axis == "yaxis":
            self.display_extent(
                0,
                self.grid_shape[0] - 1,
                slice_index,
                slice_index,
                0,
                self.grid_shape[2] - 1,
            )
        elif axis == "zaxis":
            self.display_extent(
                0,
                self.grid_shape[0] - 1,
                0,
                self.grid_shape[1] - 1,
                slice_index,
                slice_index,
            )
        else:
            raise ValueError("Invalid axis name {0}.".format(axis)) 
[docs]
    @warn_on_args_to_kwargs()
    def display(self, *, x=None, y=None, z=None):
        """Display a slice along x, y, or z axis."""
        if x is None and y is None and z is None:
            self.slice_along_axis(self.grid_shape[2] // 2)
        elif x is not None:
            self.slice_along_axis(x, axis="xaxis")
        elif y is not None:
            self.slice_along_axis(y, axis="yaxis")
        elif z is not None:
            self.slice_along_axis(z, axis="zaxis") 
[docs]
    def update_sphere(self, vertices, faces, B):
        """Dynamically change the sphere used for SH to SF projection."""
        if self.B is None:
            raise ValueError("Can't update sphere when using " "SF coefficients.")
        self.vertices = vertices
        if self.affine is not None:
            self.w_verts = self.vertices.dot(self.affine[:3, :3])
        self.faces = faces
        self.B = B
        # draw ODFs with new sphere
        self._update_mapper() 
    def _update_mapper(self):
        """Map vtkPolyData to the actor."""
        polydata = PolyData()
        offsets = self._get_odf_offsets(self.mask)
        if len(offsets) == 0:
            self.mapper.SetInputData(polydata)
            return None
        sph_dirs = self._get_sphere_directions()
        sf = self._get_sf(self.mask)
        all_vertices = self._get_all_vertices(offsets, sph_dirs, sf)
        all_faces = self._get_all_faces(len(offsets), len(sph_dirs))
        all_colors = self._generate_color_for_vertices(sf)
        # TODO: There is a lot of deep copy here.
        # Optimize (see viz_network.py example).
        set_polydata_triangles(polydata, all_faces)
        set_polydata_vertices(polydata, all_vertices)
        set_polydata_colors(polydata, all_colors)
        self.mapper.SetInputData(polydata)
    def _get_odf_offsets(self, mask):
        """Get the position of non-zero voxels inside `mask`."""
        if self.affine is not None:
            return self.w_pos[mask[self.indices]]
        return np.asarray(self.indices).T[mask[self.indices]]
    def _get_sphere_directions(self):
        """Get the sphere directions onto which is projected the signal."""
        if self.affine is not None:
            return self.w_verts
        return self.vertices
    def _get_sf(self, mask):
        """Get SF coefficients inside `mask`."""
        # when odfs are expressed in SH coefficients
        if self.B is not None:
            sf = self.odfs[mask[self.indices]].dot(self.B)
            # normalisation and scaling is done on SF coefficients
            if self.norm:
                sf /= np.abs(sf).max(axis=-1, keepdims=True)
            return sf * self.scale
        # when odfs are in SF coefficients, the normalisation and scaling
        # are done during initialisation. We simply return them:
        return self.odfs[mask[self.indices]]
    def _get_all_vertices(self, offsets, sph_dirs, sf):
        """Get array of all the vertices of the ODFs to display."""
        if self.radial_scale:
            # apply SF amplitudes to all sphere
            # directions and offset each voxel
            return np.tile(sph_dirs, (len(offsets), 1)) * sf.reshape(-1, 1) + np.repeat(
                offsets, len(sph_dirs), axis=0
            )
        # return scaled spheres offsetted by `offsets`
        return np.tile(sph_dirs, (len(offsets), 1)) * self.scale + np.repeat(
            offsets, len(sph_dirs), axis=0
        )
    def _get_all_faces(self, nb_odfs, nb_dirs):
        """Get array of all the faces of the ODFs to display."""
        return np.tile(self.faces, (nb_odfs, 1)) + np.repeat(
            np.arange(nb_odfs) * nb_dirs, len(self.faces)
        ).reshape(-1, 1)
    def _generate_color_for_vertices(self, sf):
        """Get array of all vertices colors."""
        if self.global_cm:
            if self.colormap is None:
                raise IOError("if global_cm=True, colormap must be defined.")
            else:
                all_colors = create_colormap(sf.ravel(), name=self.colormap) * 255
        elif self.colormap is not None:
            if isinstance(self.colormap, str):
                # Map ODFs values [min, max] to [0, 1] for each ODF
                range_sf = sf.max(axis=-1) - sf.min(axis=-1)
                rescaled = sf - sf.min(axis=-1, keepdims=True)
                rescaled[range_sf > 0] /= range_sf[range_sf > 0][..., None]
                all_colors = create_colormap(rescaled.ravel(), name=self.colormap) * 255
            else:
                all_colors = np.tile(
                    np.array(self.colormap).reshape(1, 3),
                    (sf.shape[0] * sf.shape[1], 1),
                )
        else:
            all_colors = np.tile(np.abs(self.vertices) * 255, (len(sf), 1))
        return all_colors.astype(np.uint8)