import os
import numpy as np
from fury import actor
from fury.lib import FloatArray, Texture
from fury.shaders import (
    attribute_to_actor,
    compose_shader,
    import_fury_shader,
    shader_to_actor,
)
from fury.texture.utils import uv_calculations
from fury.utils import minmax_norm, numpy_to_vtk_image_data, set_polydata_tcoords
def sh_odf(centers, coeffs, degree, sh_basis, scales, opacity):
    """
    Visualize one or many ODFs with different features.

    Each ODF is drawn as a box actor whose fragment shader ray traces the
    spherical harmonics (SH) glyph defined by that ODF's coefficients. The
    coefficients are normalized per ODF, quantized to 8 bits and stored in a
    texture that the fragment shader reads back.

    Parameters
    ----------
    centers : ndarray(N, 3)
        ODFs positions.
    coeffs : ndarray(N, M)
        2D ODFs array in SH coefficients, one row per ODF. M must be the
        number of coefficients of an even SH degree between 2 and 12, i.e.
        one of 6, 15, 28, 45, 66 or 91.
    degree : int
        Index of the highest used band of the spherical harmonics basis.
        Currently not used: the SH degree is derived from the number of
        columns of ``coeffs``.
    sh_basis : str
        Type of basis (descoteaux, tournier)
        'descoteaux' for the default ``descoteaux07`` DIPY basis.
        'tournier' for the default ``tournier07`` DIPY basis.
    scales : float or ndarray (N, )
        ODFs size.
    opacity : float
        Takes values from 0 (fully transparent) to 1 (opaque).

    Returns
    -------
    odf_actor : Actor
        Box actor with the ODF ray-tracing shaders attached.

    Raises
    ------
    ValueError
        If ``coeffs`` is not 2D, or its number of columns does not correspond
        to an even SH degree between 2 and 12.
    """
    coeffs = np.asarray(coeffs)
    if coeffs.ndim != 2:
        raise ValueError(
            f"coeffs must be a 2D array (one row per ODF), got {coeffs.ndim}D."
        )
    # The SH degree must be even, at least 2 and at most 12. The fragment
    # shader reads exactly (degree + 1) * (degree + 2) / 2 texels per glyph,
    # so any other column count would silently render a wrong glyph.
    valid_counts = {(d + 1) * (d + 2) // 2: d for d in range(2, 13, 2)}
    max_num_coeffs = coeffs.shape[-1]
    max_sh_degree = valid_counts.get(max_num_coeffs)
    if max_sh_degree is None:
        expected = ", ".join(str(count) for count in valid_counts)
        raise ValueError(
            f"Unsupported number of SH coefficients: {max_num_coeffs}. "
            f"Expected one of {expected} (even SH degree from 2 to 12)."
        )
    # Degree of the polynomials whose roots give the ray/glyph intersections.
    max_poly_degree = 2 * max_sh_degree + 2

    odf_actor = actor.box(centers=centers, scales=scales)
    # Disable VBO shift/scale so vertexMC stays in the same frame as the
    # "center" attribute and the camera position computed in the shader.
    odf_actor.GetMapper().SetVBOShiftScaleMethod(False)
    odf_actor.GetProperty().SetOpacity(opacity)

    # Per-glyph attributes are repeated 8 times (one per box vertex) so they
    # become per-vertex attributes.
    big_centers = np.repeat(centers, 8, axis=0)
    attribute_to_actor(odf_actor, big_centers, "center")
    # Per-glyph (min, max) of the coefficients, used by the shader to undo
    # the normalization applied before storing them in the texture.
    minmax = np.array([coeffs.min(axis=1), coeffs.max(axis=1)]).T
    big_minmax = np.repeat(minmax, 8, axis=0)
    attribute_to_actor(odf_actor, big_minmax, "minmax")

    odf_actor_pd = odf_actor.GetMapper().GetInput()
    n_glyphs = coeffs.shape[0]
    # Coordinates to locate the data of each glyph in the texture.
    uv_vals = np.array(uv_calculations(n_glyphs))
    num_pnts = uv_vals.shape[0]
    # Definition of texture coordinates to be associated with the actor.
    t_coords = FloatArray()
    t_coords.SetNumberOfComponents(2)
    t_coords.SetNumberOfTuples(num_pnts)
    for idx, uv in enumerate(uv_vals):
        t_coords.SetTuple(idx, uv)
    set_polydata_tcoords(odf_actor_pd, t_coords)

    # The coefficient data is stored in a texture to be passed to the shaders.
    # Data is normalized to a range of 0 to 1, then scaled to the 8-bit color
    # range. The scaling is done out of place so the caller's array is never
    # modified, and it rounds to the nearest level instead of truncating,
    # which halves the worst-case quantization error.
    arr = np.round(minmax_norm(coeffs) * 255)
    grid = numpy_to_vtk_image_data(arr.astype(np.uint8))
    # Vtk image data is associated to a texture.
    texture = Texture()
    texture.SetInputDataObject(grid)
    texture.Update()
    # Texture is associated with the actor
    odf_actor.GetProperty().SetTexture("texture0", texture)

    # The number of coefficients is associated to the order of the SH
    odf_actor.GetShaderProperty().GetFragmentCustomUniforms().SetUniformf(
        "shDegree", max_sh_degree
    )

    # Start of shader implementation
    vs_dec = """
        uniform float shDegree;
        in vec3 center;
        in vec2 minmax;
        flat out float numCoeffsVSOutput;
        flat out float maxPolyDegreeVSOutput;
        out vec4 vertexMCVSOutput;
        out vec3 centerMCVSOutput;
        out vec2 minmaxVSOutput;
        out vec3 camPosMCVSOutput;
        out vec3 camRightMCVSOutput;
        out vec3 camUpMCVSOutput;
        """
    vs_impl = """
        numCoeffsVSOutput = (shDegree + 1) * (shDegree + 2) / 2;
        maxPolyDegreeVSOutput = 2 * shDegree + 2;
        vertexMCVSOutput = vertexMC;
        centerMCVSOutput = center;
        minmaxVSOutput = minmax;
        camPosMCVSOutput = -MCVCMatrix[3].xyz * mat3(MCVCMatrix);
        camRightMCVSOutput = vec3(
            MCVCMatrix[0][0], MCVCMatrix[1][0], MCVCMatrix[2][0]);
        camUpMCVSOutput = vec3(
            MCVCMatrix[0][1], MCVCMatrix[1][1], MCVCMatrix[2][1]);
        """
    shader_to_actor(odf_actor, "vertex", decl_code=vs_dec, impl_code=vs_impl)

    # The index of the highest used band of the spherical harmonics basis. Must
    # be even, at least 2 and at most 12.
    def_sh_degree = f"#define SH_DEGREE {max_sh_degree}"
    # The number of spherical harmonics basis functions
    def_sh_count = f"#define SH_COUNT {max_num_coeffs}"
    # Degree of polynomials for which we have to find roots
    def_max_degree = f"#define MAX_DEGREE {max_poly_degree}"
    # If GL_EXT_control_flow_attributes is available, these defines should be
    # defined as [[unroll]] and [[loop]] to give reasonable hints to the
    # compiler. That avoids register spilling, which makes execution
    # considerably faster.
    def_gl_ext_control_flow_attributes = """
        #ifndef _unroll_
            #define _unroll_
        #endif
        #ifndef _loop_
            #define _loop_
        #endif
        """
    # When there are fewer intersections/roots than theoretically possible,
    # some array entries are set to this value
    def_no_intersection = "#define NO_INTERSECTION 3.4e38"
    # pi and its reciprocal
    def_pis = """
        #define M_PI 3.141592653589793238462643
        #define M_INV_PI 0.318309886183790671537767526745
        """
    fs_vs_vars = """
        flat in float numCoeffsVSOutput;
        flat in float maxPolyDegreeVSOutput;
        in vec4 vertexMCVSOutput;
        in vec3 centerMCVSOutput;
        in vec2 minmaxVSOutput;
        in vec3 camPosMCVSOutput;
        in vec3 camRightMCVSOutput;
        in vec3 camUpMCVSOutput;
        """
    coeffs_norm = import_fury_shader(os.path.join("utils", "minmax_norm.glsl"))
    # Basis-specific SH evaluation (and gradient) for every even band up to
    # the glyph's degree.
    eval_sh_composed = ""
    for i in range(2, max_sh_degree + 1, 2):
        eval_sh = import_fury_shader(
            os.path.join("ray_tracing", "odf", sh_basis, "eval_sh_" + str(i) + ".frag")
        )
        eval_sh_grad = import_fury_shader(
            os.path.join(
                "ray_tracing", "odf", sh_basis, "eval_sh_grad_" + str(i) + ".frag"
            )
        )
        eval_sh_composed = compose_shader([eval_sh_composed, eval_sh, eval_sh_grad])
    # Searches a single root of a polynomial within a given interval.
    #   param out_root The location of the found root.
    #   param out_end_value The value of the given polynomial at end.
    #   param poly Coefficients of the polynomial for which a root should be
    #       found.
    #       Coefficient poly[i] is multiplied by x^i.
    #   param begin The beginning of an interval where the polynomial is
    #       monotonic.
    #   param end The end of said interval.
    #   param begin_value The value of the given polynomial at begin.
    #   param error_tolerance The error tolerance for the returned root
    #       location.
    #       Typically the error will be much lower but in theory it can be
    #       bigger.
    #
    #   return true if a root was found, false if no root exists.
    newton_bisection = import_fury_shader(
        os.path.join("utils", "newton_bisection.frag")
    )
    # Finds all roots of the given polynomial in the interval [begin, end] and
    # writes them to out_roots. Some entries will be NO_INTERSECTION but other
    # than that the array is sorted. The last entry is always NO_INTERSECTION.
    find_roots = import_fury_shader(os.path.join("utils", "find_roots.frag"))
    # Evaluates the spherical harmonics basis in bands 0, 2, ..., SH_DEGREE.
    # Conventions are as in the following paper.
    # M. Descoteaux, E. Angelino, S. Fitzgibbons, and R. Deriche. Regularized,
    # fast, and robust analytical q-ball imaging. Magnetic Resonance in
    # Medicine, 58(3), 2007. https://doi.org/10.1002/mrm.21277
    #   param outSH Values of SH basis functions in bands 0, 2, ...,
    #       SH_DEGREE in this order.
    #   param point The point on the unit sphere where the basis should be
    #       evaluated.
    eval_sh = import_fury_shader(os.path.join("ray_tracing", "odf", "eval_sh.frag"))
    # Evaluates the gradient of each basis function given by eval_sh() and the
    # basis itself
    eval_sh_grad = import_fury_shader(
        os.path.join("ray_tracing", "odf", "eval_sh_grad.frag")
    )
    # Outputs a matrix that turns equidistant samples on the unit circle of a
    # homogeneous polynomial into coefficients of that polynomial.
    get_inv_vandermonde = import_fury_shader(
        os.path.join("ray_tracing", "odf", "get_inv_vandermonde.frag")
    )
    # Determines all intersections between a ray and a spherical harmonics
    # glyph.
    #   param out_ray_params The ray parameters at intersection points. The
    #       points themselves are at ray_origin + out_ray_params[i] * ray_dir.
    #       Some entries may be NO_INTERSECTION but other than that the array
    #       is sorted.
    #   param sh_coeffs SH_COUNT spherical harmonic coefficients defining the
    #       glyph. Their exact meaning is defined by eval_sh().
    #   param ray_origin The origin of the ray, relative to the glyph center.
    #   param ray_dir The normalized direction vector of the ray.
    ray_sh_glyph_intersections = import_fury_shader(
        os.path.join("ray_tracing", "odf", "ray_sh_glyph_intersections.frag")
    )
    # Provides a normalized normal vector for a spherical harmonics glyph.
    #   param sh_coeffs SH_COUNT spherical harmonic coefficients defining the
    #       glyph. Their exact meaning is defined by eval_sh().
    #   param point A point on the surface of the glyph, relative to its
    #       center.
    #
    #   return A normalized surface normal pointing away from the origin.
    get_sh_glyph_normal = import_fury_shader(
        os.path.join("ray_tracing", "odf", "get_sh_glyph_normal.frag")
    )
    # Applies the non-linearity that maps linear RGB to sRGB
    linear_to_srgb = import_fury_shader(os.path.join("lighting", "linear_to_srgb.frag"))
    # Inverse of linear_to_srgb()
    srgb_to_linear = import_fury_shader(os.path.join("lighting", "srgb_to_linear.frag"))
    # Turns a linear RGB color (i.e. rec. 709) into sRGB
    linear_rgb_to_srgb = import_fury_shader(
        os.path.join("lighting", "linear_rgb_to_srgb.frag")
    )
    # Inverse of linear_rgb_to_srgb()
    srgb_to_linear_rgb = import_fury_shader(
        os.path.join("lighting", "srgb_to_linear_rgb.frag")
    )
    # Logarithmic tonemapping operator. Input and output are linear RGB.
    tonemap = import_fury_shader(os.path.join("lighting", "tonemap.frag"))
    # Blinn-Phong illumination model
    blinn_phong_model = import_fury_shader(
        os.path.join("lighting", "blinn_phong_model.frag")
    )
    # fmt: off
    fs_dec = compose_shader([
        def_sh_degree, def_sh_count, def_max_degree,
        def_gl_ext_control_flow_attributes, def_no_intersection, def_pis,
        fs_vs_vars, coeffs_norm, eval_sh_composed, newton_bisection, find_roots,
        eval_sh, eval_sh_grad, get_inv_vandermonde, ray_sh_glyph_intersections,
        get_sh_glyph_normal, blinn_phong_model, linear_to_srgb, srgb_to_linear,
        linear_rgb_to_srgb, srgb_to_linear_rgb, tonemap
    ])
    # fmt: on
    shader_to_actor(odf_actor, "fragment", decl_code=fs_dec)

    point_from_vs = "vec3 pnt = vertexMCVSOutput.xyz;"
    # Ray origin is the camera position in world space
    ray_origin = "vec3 ro = camPosMCVSOutput;"
    # Ray direction is the normalized difference between the fragment and the
    # camera position/ray origin
    ray_direction = "vec3 rd = normalize(pnt - ro);"
    # Light direction in a retroreflective model is the normalized difference
    # between the camera position/ray origin and the fragment
    light_direction = "vec3 ld = normalize(ro - pnt);"
    # Fetch this glyph's SH coefficients from its texture row (sampling each
    # texel at its center) and map them from [0, 1] back to the glyph's
    # original [min, max] range with rescale() from minmax_norm.glsl.
    sh_coeffs = """
        float i = 1 / (numCoeffsVSOutput * 2);
        float shCoeffs[SH_COUNT];
        for(int j=0; j < numCoeffsVSOutput; j++){
            shCoeffs[j] = rescale(
                texture(
                    texture0,
                    vec2(i + j / numCoeffsVSOutput, tcoordVCVSOutput.y)).x,
                    0, 1, minmaxVSOutput.x, minmaxVSOutput.y
            );
        }
        """
    # Perform the intersection test
    intersection_test = """
        float rayParams[MAX_DEGREE];
        rayGlyphIntersections(
            rayParams, shCoeffs, ro - centerMCVSOutput, rd, int(shDegree),
            int(numCoeffsVSOutput), int(maxPolyDegreeVSOutput), M_PI,
            NO_INTERSECTION
        );
        """
    # Identify the first intersection in front of the camera
    first_intersection = """
        float firstRayParam = NO_INTERSECTION;
        _unroll_
        for (int i = 0; i != maxPolyDegreeVSOutput; ++i) {
            if (rayParams[i] != NO_INTERSECTION && rayParams[i] > 0.0) {
                firstRayParam = rayParams[i];
                break;
            }
        }
        """
    # Evaluate shading for a directional light; fragments whose ray misses
    # the glyph are discarded.
    directional_light = """
        vec3 color = vec3(1.);
        if (firstRayParam != NO_INTERSECTION) {
            vec3 intersection = ro - centerMCVSOutput + firstRayParam * rd;
            vec3 normal = getShGlyphNormal(shCoeffs, intersection,
                          int(shDegree), int(numCoeffsVSOutput));
            vec3 colorDir = srgbToLinearRgb(abs(normalize(intersection)));
            float attenuation = dot(ld, normal);
            color = blinnPhongIllumModel(
                attenuation, lightColor0, colorDir, specularPower,
                specularColor, ambientColor);
        } else {
            discard;
        }
        """
    frag_output = """
        vec3 outColor = linearRgbToSrgb(tonemap(color));
        fragOutput0 = vec4(outColor, opacity);
        """
    fs_impl = compose_shader(
        [
            point_from_vs,
            ray_origin,
            ray_direction,
            light_direction,
            sh_coeffs,
            intersection_test,
            first_intersection,
            directional_light,
            frag_output,
        ]
    )
    shader_to_actor(odf_actor, "fragment", impl_code=fs_impl, block="picking")
    return odf_actor