#ifndef COMMONS_SPHERE_SAMPLING_H
#define COMMONS_SPHERE_SAMPLING_H

// Builds an orthonormal basis around n using Frisvad's method.
// Returns mat3(b1, b2, n): column 0 and column 1 are tangents, column 2 is n (Z-up layout).
// n is expected to be unit length.
//
// Near n.z == -1 the term 1/(1 + n.z) explodes and the tangents lose precision, so a fixed
// basis is used below the threshold. This is the same threshold and fallback that
// construct_ONB_frisvad uses. The previous version only caught n.z == -1.0 exactly, and then
// patched n.z to -0.995. That produced a non-unit third column and a mirrored basis.
mat3 matrixFromVector(vec3 n) { // frisvad
    vec3 b1;
    vec3 b2;
    if (n.z < -0.999805696f) {
        // Degenerate case (n ~= -Z): fixed tangents with b1 x b2 == -Z.
        b1 = vec3(0.0, -1.0, 0.0);
        b2 = vec3(-1.0, 0.0, 0.0);
    } else {
        float a = 1.0 / (1.0 + n.z);
        float b = -n.x * n.y * a;
        b1 = vec3(1.0 - n.x * n.x * a, b, -n.x);
        b2 = vec3(b, 1.0 - n.y * n.y * a, -n.y);
    }
    return mat3(b1, b2, n);
}

// Cosine-weighted direction on the +Z hemisphere from two uniform numbers in [0, 1).
// A point is drawn uniformly on the unit disk (radius sqrt(u1), angle 2*pi*u2) and lifted
// straight up onto the hemisphere, so z = sqrt(1 - r^2) = sqrt(1 - u1).
vec3 CosineSampleHemisphere(float u1, float u2)
{
    float radius = sqrt(u1);
    float angle = 2.0 * M_PI * u2;
    vec2 disk = radius * vec2(cos(angle), sin(angle));
    // max() guards against a tiny negative argument from rounding.
    float height = sqrt(max(0.0f, 1.0 - u1));
    return vec3(disk, height);
}

// Samples the visible hemisphere of a unit hemisphere as seen from wi, using a spherical cap.
// Reference: "Sampling Visible GGX Normals with Spherical Caps", Jonathan Dupuy and Anis Benyoub.
// https://gist.github.com/jdupuy/4c6e782b62c92b9cb3d13fbb0a5bd7a0
//
// u  : two uniform random numbers in [0, 1).
// wi : direction in the hemisphere (stretched) configuration.
// Returns the halfway vector c + wi WITHOUT normalization. SampleVndf_GGX normalizes the
// result after warping it back to the ellipsoid configuration.
vec3 SampleVndf_Hemisphere(vec2 u, vec3 wi)
{
    // Cosine of the cap direction, uniform in (-wi.z, 1].
    float cosTheta = fma((1.0f - u.y), (1.0f + wi.z), -wi.z);
    float sinTheta = sqrt(clamp(1.0f - cosTheta * cosTheta, 0.0f, 1.0f));
    float azimuth = 2.0f * M_PI * u.x;
    vec3 capDir = vec3(sinTheta * vec2(cos(azimuth), sin(azimuth)), cosTheta);
    return capDir + wi;
}

// Samples a microfacet normal from the GGX distribution of visible normals (VNDF).
// wi    : view direction in the local frame.
// alpha : GGX roughness along the local x and y axes.
// u     : two uniform random numbers in [0, 1).
// Returns the sampled microfacet normal, normalized, in the same local frame as wi.
vec3 SampleVndf_GGX(vec2 u, vec3 wi, vec2 alpha)
{
    // Stretch wi so the GGX ellipsoid becomes a unit hemisphere.
    vec3 wiHemisphere = normalize(vec3(alpha * wi.xy, wi.z));
    vec3 hHemisphere = SampleVndf_Hemisphere(u, wiHemisphere);
    // Undo the stretch. The normalize also completes the (unnormalized) hemisphere sample.
    return normalize(vec3(alpha * hHemisphere.xy, hHemisphere.z));
}

// Returns v * v.
float square(float v)
{
    float product = v * v;
    return product;
}

// Orthonormal basis around `normal` (expected unit length) using Frisvad's construction,
// with a fixed fallback basis when normal.z is below -0.999805696. That threshold is the one
// proposed by N. Max, "Improved accuracy when building an orthonormal basis" (JCGT 2017).
//
// Column layout (Y-up), unlike matrixFromVector(), which puts the normal in column 2:
//   ret[0] = first tangent, ret[1] = normal, ret[2] = second tangent.
// ImportanceSampleGGX_VNDF relies on this layout. It projects V onto
// (basis[0], basis[2], basis[1]) to get a Z-up tangent frame for sampling, and rebuilds the
// world-space vector with the matching (x, z, y) component order before multiplying by basis.
mat3 construct_ONB_frisvad(vec3 normal)
{
    vec3 tangent;
    vec3 bitangent;
    if (normal.z < -0.999805696f) {
        tangent = vec3(0.0f, -1.0f, 0.0f);
        bitangent = vec3(-1.0f, 0.0f, 0.0f);
    } else {
        float invOnePlusZ = 1.0f / (1.0f + normal.z);
        float offDiag = -normal.x * normal.y * invOnePlusZ;
        tangent = vec3(1.0f - normal.x * normal.x * invOnePlusZ, offDiag, -normal.x);
        bitangent = vec3(offDiag, 1.0f - normal.y * normal.y * invOnePlusZ, -normal.y);
    }
    return mat3(tangent, normal, bitangent);
}

// Samples a world-space microfacet normal from the isotropic GGX VNDF.
// Follows the steps of Heitz, "Sampling the GGX Distribution of Visible Normals" (JCGT 2018).
//
// u         : two uniform random numbers in [0, 1).
// roughness : remapped to alpha = roughness^2.
// V         : negated before use. It presumably points toward the surface (e.g. an incoming
//             ray direction); confirm with callers.
// basis     : expected in construct_ONB_frisvad's layout (normal in column 1).
vec3 ImportanceSampleGGX_VNDF(vec2 u, float roughness, vec3 V, mat3 basis)
{
    float alpha = square(roughness);

    // -V in a Z-up tangent frame: x = basis[0], y = basis[2], z = basis[1].
    vec3 viewTangent = -vec3(dot(V, basis[0]), dot(V, basis[2]), dot(V, basis[1]));

    // Stretch into the hemisphere configuration.
    vec3 viewHemi = normalize(vec3(alpha * viewTangent.x, alpha * viewTangent.y, viewTangent.z));

    // Orthonormal frame around viewHemi; fall back to +X when viewHemi has no xy component.
    float lenSqXY = square(viewHemi.x) + square(viewHemi.y);
    vec3 axisA;
    if (lenSqXY > 0.0)
        axisA = vec3(-viewHemi.y, viewHemi.x, 0.0) * inversesqrt(lenSqXY);
    else
        axisA = vec3(1.0, 0.0, 0.0);
    vec3 axisB = cross(viewHemi, axisA);

    // Uniform point on the unit disk. Its second coordinate is warped to the projected area
    // of the visible hemisphere.
    float radius = sqrt(u.x);
    float angle = 2.0 * M_PI * u.y;
    float p1 = radius * cos(angle);
    float p2 = radius * sin(angle);
    float blend = 0.5 * (1.0 + viewHemi.z);
    p2 = (1.0 - blend) * sqrt(1.0 - square(p1)) + blend * p2;

    // Lift the disk point onto the hemisphere.
    float lift = sqrt(max(0.0, 1.0 - square(p1) - square(p2)));
    vec3 normalHemi = p1 * axisA + p2 * axisB + lift * viewHemi;

    // Unstretch. Components are ordered (x, z, y) to match basis columns (tangent, normal, bitangent).
    vec3 normalTangent = vec3(alpha * normalHemi.x, max(0.0, normalHemi.z), alpha * normalHemi.y);

    return normalize(basis * normalTangent);
}

// Smith masking term G1 for GGX:
//   G1 = 2 * NdotL / (NdotL + sqrt(alpha^2 + (1 - alpha^2) * NdotL^2)), with alpha = roughness^2.
// If the denominator is not positive, 0 is returned. That happens for alpha == 0 with
// NdotL <= 0, or for NdotL == -1, and previously produced NaN (0/0) or -inf. For every other
// input the result is unchanged.
float G1_Smith(float roughness, float NdotL)
{
    float alpha = square(roughness);
    float alphaSq = square(alpha);
    float denominator = NdotL + sqrt(alphaSq + (1.0 - alphaSq) * square(NdotL));
    return denominator > 0.0 ? 2.0 * NdotL / denominator : 0.0;
}

// Schlick's Fresnel approximation, scaled by specular_factor and clamped to [0, 1] per channel.
// F0 - reflectance at normal incidence = (n - 1)^2 / (n + 1)^2, where n is the index of refraction.
// HdotV is the cosine between the half vector and the view direction.
vec3 schlick_fresnel(vec3 F0, float HdotV, float specular_factor)
{
    // GLSL pow(x, y) is undefined for x < 0. Rounding can push HdotV slightly above 1, so clamp
    // the base at 0. Use multiplications instead of pow(x, 5.0).
    float m = max(1.0 - HdotV, 0.0);
    float m2 = m * m;
    float m5 = m2 * m2 * m;
    vec3 F = F0 + (vec3(1.0) - F0) * m5;
    F *= specular_factor;
    return clamp(F, vec3(0.0), vec3(1.0));
}

#endif