#version 460
#extension GL_EXT_debug_printf : enable

#extension GL_EXT_shader_explicit_arithmetic_types : enable
#extension GL_ARB_shader_ballot : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_EXT_control_flow_attributes : enable

// TODO: this should be made conditional
#extension GL_EXT_fragment_shading_rate : enable

#define HAS_16BIT_TYPES

// Run depth/stencil tests before this fragment shader executes.
layout(early_fragment_tests) in;

// NOTE(review): presumably tells the included raytrace headers to declare their buffers read-only — confirm in raytrace_buffers.glsl.
#define RT_READ_ONLY 1

// Blue-noise texture array used to jitter rough reflection directions; sampled at layer 0 with
// pixel coordinates wrapped by `& 127` in the direction-randomization helpers below.
uniform sampler2DArray s_BlueNoise;
// NOTE(review): not referenced in this file's visible code; presumably used by included helpers.
uniform sampler2D      s_BRDF;

// Per-ray user payload. Must be declared before the traversal includes; the traversal code
// presumably embeds it as `ray_state::user_data` (this file reads/writes state.user_data.primitiveId).
#define RT_TRAVERSAL_HAS_USER_RAY_STATE
struct ray_state_user_data
{
	uint primitiveId;	// primitive (face) index recorded by evaluate_material() on hit
};

#include <shaders/materials/commons.glsl>
#include <shaders/commons_hlsl.glsl>
#include <shaders/materials/commons_sphere_sampling.glsl>
#include <shaders/geometry_partitioning/raytrace_buffers.glsl>
#include <shaders/geometry_partitioning/raytrace_commons.glsl>

#include <shaders/materials/raytrace_simple_traversal.glsl>

// Per-entity transform parameters; this pass reads transform_params.vCameraPosition as the ray origin.
layout(std140, row_major) uniform TransformParamsBuffer{
	EntityTransformParams transform_params;
};

// Disabled NVIDIA compiler hints, kept for experimentation.
//#pragma optionNV(fastmath on)
//#pragma optionNV(fastprecision off)
//#pragma optionNV(ifcvt none)
//#pragma optionNV(inline all)
//#pragma optionNV(strict on)
//#pragma optionNV(unroll 5)

// Interpolated vertex attributes. The layout must match the vertex shader outputs at location 1.
// NOTE(review): most of these fields appear unused by this pass; positions are reconstructed from depth instead.
layout(location = 1) in struct
{
	vec3 vCoords;
	vec3 vNorm;
	vec3 vWorldNorm;
	vec3 vLocalPos;
	vec3 vCameraRelativeWorldPos;
	f16vec4 vColor;
	f16vec2 vUV0;
} vtx_input;

// Optional debug/feature switches (uncomment to enable). None of these are referenced in this
// file's visible code; presumably they are consumed by the included traversal headers.
// Marks rays which didn't hit a solid object
//#define MARK_UNFINISHED 1
//#define VISUALIZE_GRID 1
//#define VISUALIZE_HEATMAP
//#define INNER_REFLECTION
#define INTERPOLATE_NORMALS 1

// Bounce limit passed to findClosestDDAMultibounce() from main().
#define MAX_BOUNCES 1

// Overridable from the build. NOTE(review): presumably caps traversal length in the included traversal code — confirm.
#ifndef MAX_TRACE_LENGTH
#define MAX_TRACE_LENGTH 1024
#endif

// Per-pass ray tracing parameters, delivered through RTSetupBuffer (std140, row_major).
// Field order/types define the buffer layout; presumably mirrored by a host-side struct, so do not reorder.
struct RTSetup
{
	mat4  mat_projection;			// not referenced in this file's visible code
	mat4  mat_model;				// applied to the reconstructed view-space position in main()
	vec3  camera_position;			// not referenced here (main() uses transform_params.vCameraPosition)
	int   screen_sampling_scale;	// multiplier from fragment coordinates to G-buffer coordinates
	vec4  camera_projection_params;	// xy: scale, zw: offset used to rebuild the per-pixel view direction
	vec4  near_far_plane;			// packed constants for linearizeDepth(): z / (y + x - d * w)
	vec2  frustum_shift;			// frustum offset applied during view-direction reconstruction
	vec2  resolution;				// resolution used to normalize pixel positions

	uint64_t buff_grid_markers_ptr;	// copied into g_buff_grid_markers_ptr; presumably a buffer device address

	float trace_range_primary;		// forwarded to ray_traversal_params
	float trace_range_secondary;	// forwarded to ray_traversal_params

	float roughness_clamp;			// upper bound applied to roughness when jittering reflections
	float env_map_intensity;		// not referenced in this file's visible code

	int   lights_num;				// not referenced in this file's visible code

	float initial_face_start_distance;	// offset used to pre-step the ray origin off the start surface

	uint  material_flags;			// material flag mask identifying pixels to trace (used by the VRS preamble)
	int   reflect_rays;				// != 0: trace reflections, == 0: trace refractions
};

layout (std140, row_major) uniform RTSetupBuffer
{
	RTSetup rt_setup;
};

// Output 0: xyz = final traced ray direction, w = hit primitive id bit-cast to float (all bits set = no hit).
layout(location = 0) out vec4 outHitDirectionPrimitiveID;	// RGBA_FLOAT - output it as a RT because we might be using VRS
// Output 1: debug visualization (face-test heatmap / primitive id colors).
layout(location = 1) out vec4 outDebug;

// NOTE(review): not referenced in this file's visible code.
uniform sampler2D sFresnelReflection;

// G-buffer inputs: encoded normal (+ material) and metalness/roughness/material tags.
layout(r32ui)    uniform readonly uimage2D imNormalMaterial;
layout(rgba16ui) uniform readonly uimage2D imMetalnessRoughnessMaterialTags;
// Depth buffer; converted to linear depth with linearizeDepth().
uniform sampler2D sTextureDepth;

// Converts a depth-buffer sample `d` into linear depth using the packed constants in
// rt_setup.near_far_plane: z / ((y + x) - d * w). The packing itself is chosen by the host.
float linearizeDepth(in float d)
{
	vec4 nf = rt_setup.near_far_plane;
	float denominator = nf.y + nf.x - d * nf.w;
	return nf.z / denominator;
}

// Scales a view ray direction by a linear depth. With a direction whose z component is 1
// (as reconstructed in main()), the result is the view-space position at that depth.
vec3 positionFromDepth(vec3 vDirection, float depth)
{
	vec3 position = depth * vDirection;
	return position;
}

// Shorthands for the ray tracing grid cell size and its reciprocal, read from in_bbox_data
// (presumably declared by the included raytrace buffer headers).
#define GRID_SIZE in_bbox_data.grid_size_raytrace.xyz
#define GRID_SIZE_RECIP in_bbox_data.grid_size_raytrace_recip.xyz

// Polynomial approximation of the "Turbo" colormap, used here for debug visualizations.
// The input is clamped to [0, 1]; each channel is a degree-5 polynomial in x.
vec3 TurboColormap(in float x)
{
	// Coefficients for x^0..x^3 (vec4) and x^4..x^5 (vec2), per channel.
	const vec4 kRedVec4   = vec4(0.13572138, 4.61539260, -42.66032258, 132.13108234);
	const vec4 kGreenVec4 = vec4(0.09140261, 2.19418839, 4.84296658, -14.18503333);
	const vec4 kBlueVec4  = vec4(0.10667330, 12.64194608, -60.58204836, 110.36276771);
	const vec2 kRedVec2   = vec2(-152.94239396, 59.28637943);
	const vec2 kGreenVec2 = vec2(4.27729857, 2.82956604);
	const vec2 kBlueVec2  = vec2(-89.90310912, 27.34824973);

	float t  = clamp(x, 0.0, 1.0);
	float t2 = t * t;
	float t3 = t2 * t;

	vec4 low_powers  = vec4(1.0, t, t2, t3);
	vec2 high_powers = vec2(t2 * t2, t3 * t2);

	float r = dot(low_powers, kRedVec4)   + dot(high_powers, kRedVec2);
	float g = dot(low_powers, kGreenVec4) + dot(high_powers, kGreenVec2);
	float b = dot(low_powers, kBlueVec4)  + dot(high_powers, kBlueVec2);
	return vec3(r, g, b);
}

// Refracts direction `v` through a surface with normal `n`, using a fixed index ratio of 1/1.25.
// The normal is flipped when needed so it always faces against `v`, so either side of the surface
// works. The ratio is NOT inverted when leaving the medium (see TODO). When refract() reports
// total internal reflection (a zero vector), `v` is returned unchanged.
vec3 glass_refract(vec3 v, vec3 n)
{
	// TODO: add param for air->glass vs glass->air
	vec3 facing_normal = (dot(v, n) < 0.0) ? n : -n;
	vec3 refracted = refract(v, facing_normal, 1.0 / 1.25);

	bool total_internal_reflection = dot(refracted, refracted) == 0.0;
	return total_internal_reflection ? v : refracted;
}

// Hit callback for the traversal: stores the hit face index in the per-ray user data and marks the
// ray as finished with a hit (TF_RUNNING cleared, TF_HIT set). The remaining parameters are unused
// here; presumably they are part of the callback signature the included traversal code expects.
void evaluate_material(in out ray_state state, in vec3 prev_state_origin, int hit_face, uint hit_material_flags, f16vec2 bc, bool flip_normal_on_glass)
{
	state.user_data.primitiveId = uint(hit_face);
	state.flags = rt_set_mask(rt_clear_mask(state.flags, TF_RUNNING), TF_HIT);
}

// Selects which traversal implementation the `findClosest` name maps to (the naive variant is disabled).
//#define findClosest findClosestNaive
#define findClosest findClosestDDA

// Jitters an already-reflected direction to approximate rough reflections.
//
// screen_pos        - pixel position; wrapped to the 128x128 blue-noise tile (layer 0).
// dir               - reflected direction to perturb.
// n                 - surface normal; samples pointing below the surface are rejected.
// roughness         - material roughness; clamped to rt_setup.roughness_clamp and scaled by the
//                     (non-negative) cosine between dir and n.
// bounce_throughput - out: energy factor for this bounce, max(0, 1 - effective roughness);
//                     1.0 when roughness <= 0.
//
// Returns the perturbed direction. Returns dir unchanged when roughness <= 0, or when both the
// first sample and its single re-roll point below the surface.
vec3 rt_randomize_dir_for_roughness(ivec2 screen_pos, vec3 dir, vec3 n, float roughness, out float bounce_throughput)
{
	bounce_throughput = 1.0;

	// NOTE: This is REALLY costly when divergence goes to hell. Rays are not clustered yet,
	// so for now we just try to limit the roughness...
	if (roughness > 0.0f)
	{
		roughness = min(rt_setup.roughness_clamp, roughness);
		// Clamp the cosine: a reflected direction behind the (possibly normal-mapped) normal would
		// otherwise produce a negative roughness and therefore a throughput greater than 1.
		roughness = max(0.0, dot(dir, n)) * roughness;
		bounce_throughput = max(0.0, 1.0 - roughness);

		const float golden_ratio = 1.61803398875;
		int frame = globals.monotonic & 127;
		float sample_spread = roughness * roughness;

		// Blue noise, animated over frames along the golden-ratio sequence.
		vec2 noise = texelFetch(s_BlueNoise, ivec3(screen_pos & ivec2(127), 0), 0).rg;
		vec2 hash = fract(noise + frame * golden_ratio);

		mat3 vecSpace = matrixFromVector(dir);
		vec3 d = vecSpace * CosineSampleHemisphere(hash.x * sample_spread, hash.y);

		// A strongly rotated lobe at high roughness can produce samples that trace 'into the floor'.
		if (dot(d, n) < 0.0)
		{
			// Re-roll once ('Stochastic Screen-Space Reflections', Stachowiak, SIGGRAPH 2015, p. 43
			// mentions regenerating such samples). If it is still below the surface, keep dir.
			hash = fract(noise + (frame + 10) * golden_ratio);
			d = vecSpace * CosineSampleHemisphere(hash.x * sample_spread, hash.y);
			if (dot(d, n) > 0.0)
				dir = d;
		}
		else
		{
			dir = d;
		}
	}

	return dir;
}

// Reflects `dir` (the incoming view direction) about a GGX VNDF-sampled half vector around `n`.
//
// screen_pos        - pixel position; wrapped to the 128x128 blue-noise tile (layer 0).
// roughness         - material roughness; squared, then clamped by rt_setup.roughness_clamp.
// bounce_throughput - out: estimated bounce weight (G1(L) * Fresnel), or 0.0 when no valid
//                     sample above the surface was found.
//
// Returns the normalized reflected direction, or `dir` unchanged when sampling failed.
// NOTE: the call site in main() is currently disabled in favour of rt_randomize_dir_for_roughness().
vec3 rt_reflect_dir_for_roughness_sample_ggx(ivec2 screen_pos, vec3 dir, vec3 n, float roughness, out float bounce_throughput)
{
	const float golden_ratio = 1.61803398875;
	int frame = globals.monotonic & 127;
	// Squared roughness, limited by the global roughness clamp.
	float clamped_roughness = min(rt_setup.roughness_clamp, roughness * roughness);

	screen_pos = screen_pos & ivec2(127);
	vec2 noise = texelFetch(s_BlueNoise, ivec3(screen_pos, 0), 0).rg;
	vec2 hash = fract(noise + frame * golden_ratio);

	//mat3 basis = matrixFromVector(n);
	mat3 basis = construct_ONB_frisvad(n);	// NOTE: Has to be this one...

	vec3 V = dir;
	vec3 H = normalize(basis * ImportanceSampleGGX_VNDF(hash, clamped_roughness, V, basis));
	vec3 L = reflect(V, H);

	float NoL = max(0.0, dot(n, L));
	float NoV = max(0.0, -dot(n, V));
	float VoH = max(0.0, -dot(V, H));

	// NoL/NoV are clamped to >= 0 above, so the failure case is "<= 0". With the previous "< 0"
	// test this re-roll could never execute.
	if (NoL <= 0.0 || NoV <= 0.0)
	{
		hash = fract(noise + (frame + 10) * golden_ratio);
		H = normalize(basis * ImportanceSampleGGX_VNDF(hash, clamped_roughness, V, basis));
		L = reflect(V, H);

		NoL = max(0.0, dot(n, L));
		NoV = max(0.0, -dot(n, V));
		VoH = max(0.0, -dot(V, H));
	}

	if (NoL > 0.0 && NoV > 0.0)
	{
		// See the Heitz VNDF paper for the estimator explanation.
		//   (BRDF / PDF) = F * G2(V, L) / G1(V)
		// Assume G2 = G1(V) * G1(L) here and simplify that expression to just G1(L).
		float G1_NoL = G1_Smith(clamped_roughness, NoL);
		vec3 F = schlick_fresnel(vec3(0.04, 0.04, 0.04), VoH, 1.0);

		// NOTE(review): length() of a grey vec3 is sqrt(3) times its component; confirm whether a
		// scalar such as G1_NoL * F.x was intended (the call site notes this path is still broken).
		bounce_throughput = length(G1_NoL * F);
		return normalize(L);
	}

	bounce_throughput = 0.0;
	return dir;
}


// Returns true when the material referenced by the given G-buffer tags carries any of the flags
// selected by rt_setup.material_flags, i.e. when that pixel is meant to be traced.
bool trace_pass_is_selected_material(MetalnessRoughnessMeterialTags tags)
{
	return (materials.material_properties[tags.material_index].flags & rt_setup.material_flags) != 0;
}

// VRS preamble. With a coarse shading rate one invocation covers a 2x2 quad, and its sample position
// is the centre of the coarse pixel. If the pixel at gl_FragCoord is not a selected material but
// another pixel of its quad is, return the +-0.5 offset that moves the sample onto that pixel. When
// several match, the last one in (0,0), (1,0), (0,1), (1,1) order wins. Returns vec2(0) when VRS is
// not in use for this fragment or no adjustment is needed.
vec2 trace_pass_vrs_sample_offset()
{
	vec2 offset = vec2(0.0, 0.0);
	if (gl_ShadingRateEXT == 0)
		return offset;

	ivec2 frag_pos = ivec2(gl_FragCoord.xy);
	if (trace_pass_is_selected_material(decode_metalness_roughness_material_tags(imageLoad(imMetalnessRoughnessMaterialTags, frag_pos))))
		return offset;

	ivec2 quad_origin = frag_pos & ivec2(~1, ~1);
	[[unroll]] for (int i = 0; i < 4; ++i)
	{
		ivec2 quad_offset = ivec2(i & 1, i >> 1);
		MetalnessRoughnessMeterialTags tags = decode_metalness_roughness_material_tags(imageLoad(imMetalnessRoughnessMaterialTags, quad_origin + quad_offset));
		if (trace_pass_is_selected_material(tags))
			offset = vec2(quad_offset) - vec2(0.5);
	}
	return offset;
}

// Fragment entry point. Reconstructs the surface point for this pixel from depth and the G-buffer,
// then spawns a reflection ray (rt_setup.reflect_rays != 0) or a refraction ray from it and traverses
// the scene. Writes the final ray direction plus the hit primitive id (all bits set = no hit) to
// outHitDirectionPrimitiveID, and a debug visualization to outDebug.
void main()
{
	// Defaults for every path, including the early return below. outDebug must be fully written:
	// its alpha channel is not written anywhere else.
	outHitDirectionPrimitiveID = vec4(0.0, 0.0, 0.0, asfloat(-1));
	outDebug = vec4(0.0);
	g_buff_grid_markers_ptr = rt_setup.buff_grid_markers_ptr;

	vec2 frag_offset_for_vrs = trace_pass_vrs_sample_offset();
	ivec2 native_sample_pos = ivec2(gl_FragCoord.xy + frag_offset_for_vrs);
	ivec2 scaled_sample_pos = native_sample_pos * rt_setup.screen_sampling_scale;

	// NOTE: no per-material discard happens here. Only pixels that should be traced are expected to
	// reach this shader (presumably via stencil) — confirm.
	MetalnessRoughnessMeterialTags mrmt = decode_metalness_roughness_material_tags(imageLoad(imMetalnessRoughnessMaterialTags, scaled_sample_pos));
	int16_t material = int16_t(mrmt.material_index);

	// Rebuild the view ray for this pixel (z = 1).
	// NOTE: This should be replaced with proper function call, but it also does scaling:(
	vec3 view_direction;
	vec2 vd_pos = vec2(scaled_sample_pos.xy) - rt_setup.frustum_shift.xy * rt_setup.resolution.xy * vec2(0.5, -0.5);
	view_direction.x = -rt_setup.camera_projection_params.z + rt_setup.camera_projection_params.x * vd_pos.x / rt_setup.resolution.x;
	view_direction.y = -rt_setup.camera_projection_params.w + rt_setup.camera_projection_params.y * vd_pos.y / rt_setup.resolution.y;
	view_direction.z = 1.0;
	view_direction.y = -view_direction.y;

	// NOTE(review): depth is fetched at native_sample_pos while the G-buffer images use
	// scaled_sample_pos; confirm the depth texture resolution matches that choice.
	float depth = linearizeDepth(texelFetch(sTextureDepth, native_sample_pos, 0).r);
	vec3 view_coords = positionFromDepth(view_direction, depth);
	view_coords = (rt_setup.mat_model * vec4(view_coords, 1.0)).xyz;

	vec3 origin = transform_params.vCameraPosition.xyz;
	vec3 dir = -normalize(origin - view_coords);
	float closest_it = length(origin - view_coords);

	uint encoded_normal_material = imageLoad(imNormalMaterial, scaled_sample_pos).r;
	vec3 normal = normalize(decode_normal(encoded_normal_material));

	int closest_fi = -1;

	// Start the secondary ray on the reconstructed surface point. Lighting for the primary hit is
	// calculated in the normal lighting pass, not here, so it will be multiplied later.
	origin = origin + dir * closest_it;

	ray_state state;
	ray_traversal_params traversal_params;

	traversal_params.trace_range_primary   = rt_setup.trace_range_primary;
	traversal_params.trace_range_secondary = rt_setup.trace_range_secondary;

	state.normal                 = normal;
	state.material               = material;
	state.dir                    = dir;
	state.origin                 = origin;
	state.bounces                = int16_t(1);
	state.tests                  = 0;
	state.face_tests             = 0;
	state.active_threads_factor  = 0;
	state.active_threads_samples = 0;
	state.final_color_factor     = 1.0;

	state.flags = 0;
	state.flags = rt_set_mask(state.flags, TF_RUNNING);

	// Hard-coded debug pixel: enables TF_DEBUG for a single fragment.
	if (int(gl_FragCoord.x) == 547 && int(gl_FragCoord.y) == 649)
		state.flags = rt_set_mask(state.flags, TF_DEBUG);

	const bool pre_step_along_ray    = false;
	const bool pre_step_along_normal = !pre_step_along_ray;

	// NOTE: For normal-based pre-stepping we should not use a normal displaced by the normal map
	// (or a procedural one). It should be the normal of the face the trace starts on, which we
	// don't have here:(
	if (pre_step_along_normal)
	{
		// Step along +normal for reflected rays and along -normal for refracted rays.
		if (rt_setup.reflect_rays != 0)
			state.origin += vec3(normal) * rt_setup.initial_face_start_distance;
		else
			state.origin += vec3(-normal) * rt_setup.initial_face_start_distance;
	}

	// This is the initial sample and, for now, the only place where roughness is applied.
	if (rt_setup.reflect_rays != 0)
	{
#if 1
		state.dir = reflect(state.dir, vec3(state.normal));

		// NOTE: This is REALLY costly when divergence goes to hell, so because rays are not
		// clustered yet the helper limits the roughness.
		state.dir = rt_randomize_dir_for_roughness(ivec2(native_sample_pos), state.dir, state.normal, mrmt.roughness, state.final_color_factor);
#else
		// NOTE: There is still something broken here:(
		state.dir = rt_reflect_dir_for_roughness_sample_ggx(
			ivec2(native_sample_pos),
			state.dir,
			state.normal,
			materials.material_properties[state.material].roughness,
			state.final_color_factor
		);
#endif

		// Nothing to trace when the bounce carries no energy; the outputs keep their defaults.
		if (state.final_color_factor <= 0.0)
			return;
	}
	else
	{
		// NOTE: Add same kind of roughness handling here
		state.dir = glass_refract(state.dir, -state.normal);
		state.flags = rt_set_mask(state.flags, TF_INSIDE_TRANSPARENT);
	}

	if (pre_step_along_ray)
		state.origin += state.dir * rt_setup.initial_face_start_distance;

	// Helper invocations must still enter the traversal (the wave has to stay active),
	// but they must not trace.
	if (gl_HelperInvocation)
		state.flags = rt_clear_mask(state.flags, TF_RUNNING);

	findClosestDDAMultibounce(traversal_params, state, closest_fi, closest_fi, closest_it, MAX_BOUNCES);

	// Debug heatmap for rays that needed more than 4096 face tests (black otherwise). Both the gate
	// and the color use the same counter.
	if (state.face_tests > 4096)
		outDebug.rgb = TurboColormap(min(1.0, float(state.face_tests - 4096) / 1024.0f));

	uint hit_primitive_id = 0xFFFFFFFFu;	// "no hit" sentinel, same bits as the default written above
	if (rt_is_mask_set(state.flags, TF_HIT))
	{
		hit_primitive_id = state.user_data.primitiveId;
		outDebug.rgb = TurboColormap(fract(float(hit_primitive_id) / 512.0f));
	}

	outHitDirectionPrimitiveID = vec4(state.dir, asfloat(hit_primitive_id));
}


