#version 460
#extension GL_EXT_debug_printf : enable

#define LIGHT_PROPERTIES_BINDING 1
#define MATERIAL_PROPERTIES_BINDING 1

//#define VISUALIZE_PENUMBRA

#define USE_AMBIENT_OCCLUSION_TERM

// G-buffer inputs, fetched per pixel with texelFetch() in main().
uniform sampler2D   sDepth;			// scene depth; .r is converted with linearizeDepth()
uniform usampler2D  sNormalMaterial;	// packed normal + material id (see decode_normal / decode_material)
uniform sampler2D   sAlbedo;		// base color; passed through as-is for MaterialFlag_DisableLighting materials
uniform sampler2D   sEmissive;		// emissive, modulated by the material's emissive color in main()
uniform usampler2D  sMetalnessRoughnessMaterialTags;	// metalness, roughness, material index and tags (see decode_metalness_roughness_material_tags)

// NOTE(review): the samplers below are not referenced by the code in this file; an included file may use them — confirm before removing.
uniform sampler2D   sScreenSpaceOcclusion;
uniform sampler2D   sVoxelLighting;
uniform sampler2D   sVoxelOcclusion;

#include <shaders/materials/commons_deferred.glsl>
#include <shaders/materials/commons_gradient.glsl>
#include <shaders/materials/commons.glsl>

// Per-pass parameters for the deferred lights pass, delivered through a std140 uniform block.
// Field order and types define the buffer layout and must stay in sync with the host-side
// struct that fills it (TODO: confirm which host struct mirrors this).
struct DeferredRenderLightsParams
{
	vec2  frustum_shift;		// TODO: Remove this. It is still here because of the non-standard way this code is invoked
	vec2  resolution;		// render target size, presumably in pixels; divides screen_pos in get_view_direction()
	int   lights_num;		// number of valid entries in lights.light_properties to iterate
	float env_map_intensity;	// not referenced in this file
	float raytrace_scaling_factor;	// not referenced in this file
	float raytrace_strength;	// not referenced in this file
	
};

// Shared deferred/camera parameters; this file reads camera_near_far_plane, camera_projection_params,
// camera_position and mModel from it.
layout(std140, row_major) uniform BasicDeferredParamsBuffer{
	BasicDeferredParams basic_params;
};

// Parameters specific to this lighting pass (see DeferredRenderLightsParams).
layout(std140, row_major) uniform DeferredRenderLightsParamsBuffer {
	DeferredRenderLightsParams render_lights_params;
};

// Composite settings; this file reads the volumetric fog density and fog height from it.
layout (std140, row_major) uniform DeferredCompositeSetupBuffer {
	DeferredCompositeSetup composite_setup;
};

uniform sampler2D      s_BRDF;		// NOTE(review): not referenced by the code in this file
uniform sampler2DArray s_BlueNoise;	// noise layers; main() fetches .rg as per-pixel ray-march jitter (layer = frame & 15)

in vec2 vTexcoord0;	// NOTE(review): not referenced in this file
in vec4 vFrustum;	// per-fragment frustum vector; .xyz is forwarded to helpers and used as a ray direction in main()

// output

out vec4 outColor;	// rgb: volumetric light color (scaled by 0.1 in main), a: volumetric opacity

#include <shaders/commons_hlsl.glsl>
#include <shaders/materials/commons.glsl>
#include <shaders/deferred/lighting/lighting_support.glsl>

// Far-plane distance of the camera, stored in component .y of camera_near_far_plane
// (the same component linearizeDepth() uses as its far term).
float frustum_far_plane()
{
	vec4 near_far = basic_params.camera_near_far_plane;
	return near_far.y;
}

// Converts a depth-buffer value d into linear view depth using the packed camera constants:
//   linear = nf.z / (nf.y - d * nf.w)
// (presumably nf.z = near*far, nf.y = far, nf.w = far - near — confirm against the host setup).
float linearizeDepth(float d)
{
	vec4 nf = basic_params.camera_near_far_plane;
	float denom = nf.y - d * nf.w;
	return nf.z / denom;
}

// Scales a view-space ray direction by a linear depth to get a view-space position.
// get_view_direction() produces directions with z == 1, so 'depth' is then the view-space z.
vec3 positionFromDepth(vec3 vDirection, float depth)
{
	return depth * vDirection;
}

// Samples a comparison shadow map at a clip-space shadow coordinate.
//
// smpl            - comparison sampler for the light's shadow map
// coords          - clip-space position in the light's shadow projection (before perspective divide)
// in_frustum      - out: 1.0 if coords is in front of the light and inside its x/y frustum, else 0.0
// projector_color - out: currently always vec4(1.0) (projector lookup not reimplemented yet)
// shadow_dist     - out: currently always 0.0 (distance lookup not reimplemented yet)
// returns         - the textureProj comparison result when in frustum; 0.0 when behind the light,
//                   1.0 when outside its x/y frustum (callers multiply by in_frustum)
float sampleShadowDist(in sampler2DShadow smpl, in vec4 coords, out float in_frustum, out vec4 projector_color, out float shadow_dist)
{
	// Write every out parameter before any early return: GLSL leaves unwritten outs undefined, and
	// calculate_shadow_dist_for_position() reads projector_color (scaling it by in_frustum) even for
	// rejected samples, where an undefined value could turn into NaN.
	in_frustum = 0.0;
	projector_color = vec4(1.0);
	shadow_dist = 0.0;

	// no shading if fragment behind
	if (coords.w <= 0.0)
		return 0.0;

	// no shading if fragment outside of light frustum
	if (coords.x < -coords.w || coords.x > coords.w || coords.y < -coords.w || coords.y > coords.w)
		return 1.0;

	in_frustum = 1.0;

	// Map x/y from [-w, w] to [0, w] so that after textureProj's divide by w they land in [0, 1].
	coords.xy = coords.xy * vec2(0.5) + vec2(0.5) * coords.w;

	// Flip V; done in homogeneous space so textureProj's divide by w still applies.
	coords.y = coords.y /coords.w;
	coords.y = 1.0 - coords.y;
	coords.y = coords.y * coords.w;

	// TODO: Reimplement projector color and shadow distance, e.g.
	//projector_color = texture(LightProjectorSamplers[light.projector_sampler], coords.xy / coords.w);
	//shadow_dist = texture(LightShadowmapCmpSamplers[light.projector_sampler], coords.xy / coords.w).r - coords.z / coords.w;
	return textureProj(smpl, coords);
}

// Looks up a projector texture at a clip-space coordinate in the light's projection.
// Returns vec4(0.0) when the point is behind the projector (w <= 0) or outside its x/y frustum;
// otherwise samples smpl at the projected [0, 1] UV (V flipped on Vulkan/SPIR-V builds).
vec4 sampleProjectorTexture(in sampler2D smpl, in vec4 coords)
{
	if (coords.w <= 0.0)
		return vec4(0.0);

	bool outside_x = coords.x < -coords.w || coords.x > coords.w;
	bool outside_y = coords.y < -coords.w || coords.y > coords.w;
	if (outside_x || outside_y)
		return vec4(0.0);

	vec4 uvw = coords;
	uvw.xy = uvw.xy * vec2(0.5) + vec2(0.5) * uvw.w;

#ifdef SPIRV_VULKAN
	// Flip V in homogeneous space.
	uvw.y = (1.0 - uvw.y / uvw.w) * uvw.w;
#endif

	//shadow_dist = texture(LightShadowmapCmpSamplers[light.projector_sampler], coords.xy / coords.w).r - coords.z / coords.w;
	return texture(smpl, uvw.xy / uvw.w);
}

// Shadow value for a world-space position using the light's first shadow matrix and its
// shadowmap_sampler0 comparison sampler (via sampleShadow). The result is zeroed when the position
// is outside the light frustum. cascadeColor is set to a constant reddish tint; 'depth' is unused.
float calculate_shadow_for_position(in LightProperties light, vec3 world, float depth, out vec4 cascadeColor, out vec4 projector_color)
{
	cascadeColor = vec4(1.0, 0.1, 0.1, 1.0);

	float frustum_mask;
	vec4 shadow_coords = light.mat_shadow_mvp[0] * vec4(world, 1.0);
	float occlusion = sampleShadow(LightShadowmapCmpSamplers[light.shadowmap_sampler0], shadow_coords, frustum_mask, projector_color);
	return occlusion * frustum_mask;
}

// Like calculate_shadow_for_position, but samples the light's downsampled shadow map through
// sampleShadowDist and also returns projector color and shadow distance. Both the shadow value and
// projector_color are zeroed outside the light frustum. 'depth' is unused.
float calculate_shadow_dist_for_position(in LightProperties light, vec3 world, float depth, out vec4 cascadeColor, out vec4 projector_color, out float shadow_dist)
{
	cascadeColor = vec4(1.0, 0.1, 0.1, 1.0);

	float frustum_mask;
	vec4 shadow_coords = light.mat_shadow_mvp[0] * vec4(world, 1.0);
	float occlusion = sampleShadowDist(LightShadowmapCmpSamplers[light.downsampled_shadowmap_sampler], shadow_coords, frustum_mask, projector_color, shadow_dist);

	projector_color *= frustum_mask;
	return occlusion * frustum_mask;
}

// ray-cone intersection. could be removed when we finally use the bounding primitive
// code based on https://www.shadertoy.com/view/MtcXWr

// Finite single-nappe cone used for ray/light-volume tests.
struct Cone
{
	float cosa;	// cosine of the half cone angle (compared against normalized dot products; built from light.cutoff)
	float h;	// height along the axis from the tip
	vec3 c;		// tip position
	vec3 v;		// axis direction (assumed normalized — callers pass light.direction.xyz)
};

// Ray for the cone test; callers in this file pass a normalized direction.
struct Ray
{
	vec3 o;		// origin
	vec3 d;		// direction
};

// Tests whether the offset p (point minus cone tip) lies inside a cone with axis p0.
// 'angle' is the cosine of the half angle, and 'height' bounds the distance |p| from the tip
// (a spherical cap, not a planar one). Assumes p0 is normalized.
bool inside_light_cone(vec3 p0, vec3 p, float angle, float height)
{
	float dist_sq = dot(p, p);
	if (dist_sq > height * height)
		return false;

	float axial = dot(p0, p);
	return axial >= sqrt(dist_sq) * angle;
}

// Ray vs. finite cone test (tip s.c, axis s.v, cosine half-angle s.cosa, height s.h).
//
// Solves the quadratic for the infinite double cone, then:
//  - returns true immediately if the ray origin is inside the cone (per inside_light_cone);
//  - otherwise picks the nearest non-negative root and accepts it if the hit lies on the positive
//    nappe with axial height in (0, s.h), falling back to the farther root.
//
// max_t          - unused
// v              - out: always 0.0 (debug values are commented out)
// intersections  - out: the two quadratic roots (t1, t2), unordered; (0, 0) if there are no real roots
//
// NOTE(review): the det < 0 early-out runs before the is_inside check, so an origin inside the cone
// with a ray that never meets the infinite cone surface (e.g. along the axis) returns false.
// NOTE(review): a == 0 (ray parallel to the cone surface) divides by zero below.
bool intersect_cone(Cone s, Ray r, float max_t, out float v, out vec2 intersections)
{
	v = 0.0;
	intersections = vec2(0.0, 0.0);

	bool is_inside = inside_light_cone(s.v, r.o - s.c, s.cosa, s.h);
	vec3 co = r.o - s.c;

	// Quadratic a*t^2 + b*t + c = 0 for (dot(P - C, V))^2 = |P - C|^2 * cos^2 with P = o + t*d.
	float a = dot(r.d,s.v)*dot(r.d,s.v) - s.cosa*s.cosa;
	float b = 2. * (dot(r.d,s.v)*dot(co,s.v) - dot(r.d,co)*s.cosa*s.cosa);
	float c = dot(co,s.v)*dot(co,s.v) - dot(co,co)*s.cosa*s.cosa;

	float det = b*b - 4.*a*c;
	if (det < 0.)
	{
		//v = 0.7;
		return false;
	}

	det = sqrt(det);
	float t1 = (-b - det) / (2. * a);
	float t2 = (-b + det) / (2. * a);

	intersections = vec2(t1, t2);
	if (is_inside)
		return true;

	// This is a bit messy; there ought to be a more elegant solution.
	float t = t1;
	//	if (t < 0.) t = t2;

	// && binds tighter than ||: take t2 if t1 is negative, or if t2 is a nearer positive root.
	if (t < 0. || t2 > 0. && t2 < t) t = t2;
	//if (t < 0. || t2 > t) t = t2;		// we actually test for further intersection
	//if (t < 0.) return false;

	if (t < 0.0)
		return false;

	vec3 cp = r.o + t*r.d - s.c;
	float h = dot(cp, s.v);
	if (h < 0.0)
	{
		// can happen if nearest intersection is for the 'negative cone'
		cp = r.o + max(t1, t2)*r.d - s.c;
		h = dot(cp, s.v);
		if (h < 0.0)
			return false;
	}
	if (h < s.h)
		return true;
	
	// check for far intersection if exists
	{
		cp = r.o + max(t1, t2)*r.d - s.c;
		h = dot(cp, s.v);
		if (h > 0.0 && h < s.h)
			return true;
	}

	//v = 0.1;
	return false;
	
	// cap the cone with a plane. NOTE: not needed because we also check the further intersection which seems to be enough
#if 0
	{
		vec3 P0 = s.c + s.h * s.v;
		float d = dot(normalize(s.v), r.d);
		if (d > 0.0)
			return false;

		//float tp = -(dot(r.o, normalize(s.v)) + length(P0)) / d;
		float tp = -(dot(r.o - P0, normalize(s.v))) / d;

		if (t1 > t2)
		{
			float st = t1;
			t1 = t2;
			t2 = st;
		}

		if (tp < t1 || tp > t2)
			return false;
	}
	return true;
#endif
}


#include <shaders/materials/noise/noise3d.glsl>

// all the positions are in world coords
// Simple Lambert + Phong lighting in world space (all positions in world coordinates).
// The diffuse term is max(NdotL, 0); a Phong highlight (exponent 14) is added only when the diffuse
// term is positive. Both terms are tinted by light.diffuse.xyz.
vec3 calculate_lighting_world(LightProperties light, in vec3 pos, in vec3 normal, in vec3 light_pos, in vec3 cam_pos, in float NdotL)
{
	float lambert = (NdotL < 0.0) ? 0.0 : NdotL;

	vec3 highlight = vec3(0.0);
	if (lambert > 0.0)
	{
		vec3 to_light  = normalize(light_pos - pos);
		vec3 to_viewer = normalize(cam_pos - pos);
		vec3 mirrored  = reflect(to_light, normalize(normal));
		highlight = pow(max(0.0, dot(mirrored, -to_viewer)), 14.0) * light.diffuse.xyz;
	}

	return vec3(lambert) * light.diffuse.xyz + highlight;
}

// Builds the view-space ray through a pixel, with z fixed to 1.0 so that scaling it by a linear
// depth gives the view-space position (see positionFromDepth). The pixel is first offset by
// frustum_shift (scaled by half the resolution, with y negated), then mapped through
// camera_projection_params (.xy scale, .zw offset). Y is flipped on Vulkan/SPIR-V builds.
vec3 get_view_direction(vec2 screen_pos)
{
	vec4 proj = basic_params.camera_projection_params;
	vec2 res  = render_lights_params.resolution.xy;

	vec2 shifted = screen_pos - render_lights_params.frustum_shift.xy * res * vec2(0.5, -0.5);
	vec2 dir_xy  = -proj.zw + proj.xy * shifted / res;

#ifdef SPIRV_VULKAN
	dir_xy.y = -dir_xy.y;
#endif

	return vec3(dir_xy, 1.0);
}

// Ray-marches one spot light's volume along the current pixel's view ray.
//
// Marching starts at a jittered depth. The step length grows geometrically (x1.03 per step) for up
// to MAX_STEPS steps, stopping after the first sample past max_depth. Samples inside the spot cone
// get weight ((cone falloff) * (linear range falloff))^3. The shadow value of each sample is
// subtracted; the code treats (1 - shadow_value) as the lit fraction.
//
// light           - spot light to march
// screen_pos      - pixel coordinate used to build the view ray
// jitterHash      - per-pixel noise; .x offsets the start depth to hide banding
// frustum         - unused (kept for interface compatibility)
// max_depth       - linear scene depth at this pixel
// attenuation     - out: sum of sample weights inside the cone
// projector_color - out: .rgb = projector color weighted by lit fraction and sample weight, averaged
//                   over the samples taken; .a = 0.0
// returns         - (sum of weights - sum of weighted shadow values) / number of samples taken
float volumetric_sample_shadow_spot(in LightProperties light, ivec2 screen_pos, vec2 jitterHash, in vec3 frustum, in float max_depth, out float attenuation, out vec4 projector_color)
{
	const int MAX_STEPS = 256;

	float fact = 0.0;
	float depth_dir = 0.1;	// this is completely arbitrary...
	float depth_pos = 0.0;
	int samples_num = 0;

	vec4 cascade_color;

	// GLSL 'out' parameters are undefined on entry (the caller's value is not copied in), and the
	// loop below accumulates into projector_color with +=, so both outputs are initialized here.
	attenuation = 0.0;
	projector_color = vec4(0.0);

	// jitter for banding removal. this is currently also completely arbitrary....
	depth_pos = depth_pos + depth_dir * jitterHash.x * 1000.0;

	// The view ray depends only on the pixel, so compute it once rather than per step.
	vec3 view_direction = get_view_direction(vec2(screen_pos));

	for (int i = 0; i < MAX_STEPS; i++)
	{
		samples_num++;

		vec3 ray_world = (basic_params.mModel * vec4(positionFromDepth(view_direction, depth_pos), 1.0)).xyz;
		float lf = 0.0;

		float falloff = dot(light.direction.xyz, normalize(ray_world - light.position.xyz));

		if (falloff > light.cutoff)
		{
			float sample_attenuation = 1.0 - (1.0 - falloff) / (1.0 - light.cutoff);
			sample_attenuation *= 1.0 - clamp(length(light.position.xyz - ray_world.xyz) / light.range, 0.0, 1.0);
			sample_attenuation = pow(sample_attenuation, 3.0);

			vec4 projector;
			float shadow_dist;
			float shadow_value;
			shadow_value = calculate_shadow_dist_for_position(light, ray_world.xyz, depth_pos, cascade_color, projector, shadow_dist);
			lf = sample_attenuation * shadow_value;
			attenuation += sample_attenuation;

			// maybe just consider using the ground level??
			#if 0
			if (dot(projector.rgb, projector.rgb) > 0.0 && shadow_dist > 0.0)
			{
				// distance to the projection plane
				//float threshold = 0.0028;
				const float threshold = 0.0038;
				const float min_shadow_dist_factor = 0.001;
				if (shadow_dist < threshold)
				{
					float noise = snoise(ray_world.xyz * 2.0 + vec3(globals.global_time, 0.0, -gTime * 0.3));
					noise += snoise(ray_world.xyz * 1.10231 + vec3(-globals.global_time * 0.21, 3.021, -gTime * 0.031));
					//float noise = 1.0;

					float f = (threshold - shadow_dist) * (1.0 / threshold);
					f = f * f;	// not really needed? -- kiero
					shadow_dist = f * (0.5 + 0.5 * noise);
				}
				else
					shadow_dist = 0.0;

				projector_color.rgb += sample_attenuation * projector.rgb * (min_shadow_dist_factor + (1.0 - min_shadow_dist_factor) * shadow_dist);
			}
			#else
			projector_color.rgb += projector.rgb * (1.0 - shadow_value) * sample_attenuation;
			#endif
		}

		fact += lf;

		depth_dir = depth_dir * 1.03;
		depth_pos += depth_dir;

		if (depth_pos > max_depth)
			break;
	}

	// Normalize by the number of samples actually taken. This is always >= 1 because the loop body
	// runs at least once, and equals MAX_STEPS when the march never breaks early.
	projector_color.rgb /= float(samples_num);
	return (attenuation - fact) / float(samples_num);
}

// Height-based fog density multiplier at world position p: decays exponentially with p.y, using
// composite_setup.volumetric_light_fog_height as the scale height.
float fog_density_for_position(vec3 p)
{
	float height_ratio = p.y / composite_setup.volumetric_light_fog_height;
	return exp(-height_ratio);
}

// Ray-marches the current pixel's view ray through every volumetric light and accumulates
// in-scattered light color plus a fog-weighted lit fraction.
//
// The ray is sampled at STEPS uniformly spaced fractions t of frustum_depth (the first offset by
// jitterHash.x), stopping once the sample depth exceeds max_depth. Each sample is weighted by a
// height-based fog density. Each light flagged LightType_Volumetric whose cone (light.cutoff)
// contains the sample contributes; its shadow map and, for LightType_Projector lights, its projector
// texture modulate the contribution.
//
// screen_pos    - pixel coordinate used to build the view ray
// jitterHash    - per-pixel noise; .x offsets the first sample to hide banding
// frustum       - unused (kept for interface compatibility)
// frustum_depth - depth that t == 1 maps to (main() passes frustum_far_plane())
// max_depth     - linear scene depth at this pixel
// attenuation   - out: always 0.0
// color         - out: accumulated in-scattered light color
// returns       - fog-weighted lit fraction normalized by STEPS * lights_num, or 0.0 with no lights
float volumetric_sample_shadow_all_lights(ivec2 screen_pos, vec2 jitterHash, in vec3 frustum, in float frustum_depth, in float max_depth, out float attenuation, out vec3 color)
{
	const int   STEPS = 192;
	const float view_depth_step = 1.0 / float(STEPS);

	attenuation = 0.0;
	color = vec3(0.0);

	// The result is normalized by lights_num; with no lights that would be 0/0 (NaN) and poison the
	// output opacity, so return "no fog" instead.
	int lights_num = render_lights_params.lights_num;
	if (lights_num <= 0)
		return 0.0;

	float fog_density = composite_setup.volumetric_light_fog_density;
	float fact = 0.0;

	// jitter for banding removal. this is currently also completely arbitrary....
	float t = jitterHash.x * view_depth_step;

	// The view ray depends only on the pixel, so compute it once rather than per step.
	vec3 view_direction = get_view_direction(vec2(screen_pos));

	for (int i = 0; i < STEPS; i++)
	{
		float view_depth = t * frustum_depth;
		if (view_depth > max_depth)
			break;

		vec3 ray_world = vector_transform_by_mat43(positionFromDepth(view_direction, view_depth), basic_params.mModel).xyz;
		float lf = 0.0;

		float local_fog_density = fog_density * fog_density_for_position(ray_world);
		// Opacity of this step, 1 - exp(-density * step); the same for every light at this sample.
		float af = 1.0 - exp(-local_fog_density * view_depth_step);

		for (int light_idx = 0; light_idx < lights_num; light_idx++)
		{
			LightProperties light = lights.light_properties[light_idx];

			// Only volumetric lights participate. Skipping them explicitly (instead of forcing falloff
			// to 0.0 and relying on the cutoff test) keeps them out even when light.cutoff is negative.
			if ((light.type & LightType_Volumetric) == 0)
				continue;

			float falloff = dot(light.direction.xyz, normalize(ray_world - light.position.xyz));

			if (falloff > light.cutoff)
			{
				vec3 light_attenuation_color = light_calculate_spot_attenuation_color(light, ray_world.xyz);

				vec4 projector = vec4(1.0);

				vec4 vShadowCoords = light.mat_shadow_mvp[0] * vec4(ray_world.xyz, 1.0);
				bool is_in_frustum = true;

				if (vShadowCoords.w <= 0.0)
					is_in_frustum = false;
				else
				{
					// no shading if fragment outside of light frustum
					if (vShadowCoords.x < -vShadowCoords.w || vShadowCoords.x > vShadowCoords.w || vShadowCoords.y < -vShadowCoords.w || vShadowCoords.y > vShadowCoords.w)
						is_in_frustum = false;
				}

				// Samples outside the shadow frustum are treated as fully lit (shadow_value == 0).
				float shadow_value = 0.0;
				if (is_in_frustum)
				{
					vec4 sample_coords = vShadowCoords;
					sample_coords.xy = sample_coords.xy * vec2(0.5) + vec2(0.5) * sample_coords.w;

					// Flip V in homogeneous space so textureProj's divide by w still applies.
					sample_coords.y = sample_coords.y /sample_coords.w;
					sample_coords.y = 1.0 - sample_coords.y;
					sample_coords.y = sample_coords.y * sample_coords.w;

					shadow_value = textureProj(LightShadowmapCmpSamplers[light.downsampled_shadowmap_sampler], sample_coords);

					// NOTE: In theory, we would like to have scattering support, so some distance behind the shadow
					// the shadowing should fade out (e.g. scale shadow_value by 1 - min(shadow_dist / 1400.0, 1.0)).
					// This is obviously fake, but maybe good to have smoother cutouts?

					light_attenuation_color *= (1.0 - shadow_value);

					if ((light.type & LightType_Projector) != 0)
					{
						// NOTE: We should make the LOD controllable. For now make it arbitrary (low)
						// which should be fine for imported texres, and the generated ones should
						// be low res
						vec2 projected_sample_coords = sample_coords.xy / sample_coords.w;
						projector.rgb = textureLod(LightProjectorSamplers[light.projector_sampler], projected_sample_coords, 1.5).rgb * light.projector_intensity;
					}
				}

				lf += (1.0 - shadow_value);

				color.rgb += light.diffuse.rgb * light.intensity * projector.rgb * light_attenuation_color * af;
			}
		}

		fact += lf * local_fog_density * view_depth_step;

		t = t + view_depth_step;
	}

	return fact / (float(STEPS) * float(lights_num));
}

// Fragment entry point for the volumetric lights pass.
//
// Reconstructs this pixel's depth and world position from the G-buffer and resolves the material
// flags, including per-pixel overrides of the low 4 flag bits. Materials flagged
// MaterialFlag_DisableLighting pass albedo through with alpha 0. Everything else gets the ray-marched
// volumetric lighting for all lights:
// outColor.rgb = in-scattered color * 0.1, outColor.a = volumetric opacity.
//
// NOTE(review): vNorm, materialId-derived is_background, metalness, roughness, base_emissive,
// is_particle and view are computed but not used by the active (use_all_lights == true) path.
void main() {
	ivec2 screen_pos = ivec2(gl_FragCoord.xy);

	vec4  base_color              = texelFetch(sAlbedo, screen_pos, 0);
	uint  encoded_normal_material = texelFetch(sNormalMaterial, screen_pos, 0).r;
	vec3  vNorm                   = decode_normal(encoded_normal_material);
	int   materialId              = decode_material(encoded_normal_material);

	// Linear depth along the view ray (view direction has z == 1), then to world space via mModel.
	float depth                   = linearizeDepth(texelFetch(sDepth, screen_pos, 0).r);
	vec3  view_direction          = get_view_direction(vec2(screen_pos));
	//vec3  world                   = basic_params.camera_position.xyz + positionFromDepth(vFrustum.xyz, depth);
	vec3  world                   = (basic_params.mModel * vec4(positionFromDepth(view_direction, depth), 1.0)).xyz;

	// Debug: visualize reconstructed world position.
	//outColor.rgb = fract(world * 0.01);
	//return;

	outColor = vec4(0.0);

	MetalnessRoughnessMeterialTags metalness_roughness_material_tags;
	metalness_roughness_material_tags = decode_metalness_roughness_material_tags(texelFetch(sMetalnessRoughnessMaterialTags, screen_pos, 0).rgba);

	float metalness = metalness_roughness_material_tags.metalness;
	float roughness = metalness_roughness_material_tags.roughness;
	uint material   = metalness_roughness_material_tags.material_index;

	uint material_flags = materials.material_properties[material].flags;
	// also check material overrides. only include 4 attributes
	// (when MaterialFlag_OverrideFlags is set, the low 4 flag bits come from the per-pixel overrides)
	if ((metalness_roughness_material_tags.material_flag_overrides & MaterialFlag_OverrideFlags) != 0)
	{
		material_flags &= ~0xf;
		material_flags |= metalness_roughness_material_tags.material_flag_overrides & 0xf;
	}

	// Unlit materials: pass albedo through with zero opacity and skip all lighting.
	if ((material_flags & MaterialFlag_DisableLighting) != 0)
	{
		outColor = base_color;
		outColor.a = 0.0;
		return;
	}

	// NOTE: because the model of PBR we use, and it allowing for 0...1 value for base color it doesn't play
	// well with our totally mixed up pipe. For RT we do a hacky job and use color and magnitude separately.
	// How this is going to play? No idea

	vec3 base_emissive = texelFetch(sEmissive, screen_pos, 0).rgb * materials.material_properties[material].emissive.rgb;
	bool is_background = (materialId & MATERIAL_ID_MASK_ATTR) == ATTR_BACKGROUND;
	bool is_particle   = (material_flags & (MaterialFlag_ParticleLighting)) != 0;

	vec3 view = normalize(basic_params.camera_position.xyz - world.xyz);
	vec3 outLightColor = vec3(0.0);
	float outOpacity = 0.0;

#ifdef VISUALIZE_PENUMBRA
	float vis_penumbra = 0.0;
#endif

	// Constant path switch: true = single march over all lights (active path);
	// false = legacy per-light march below (currently dead code).
	bool use_all_lights = true;

	if (use_all_lights == true)
	{
		// Blue-noise jitter, cycling through 16 layers by frame counter.
		vec2 jitterHash = texelFetch(s_BlueNoise, ivec3(screen_pos.xy & ivec2(127), (globals.monotonic) & 15), 0).rg;
		float attenuation = 0.0;
		vec3 volume_color = vec3(0.0);
		float volume_attenuation = volumetric_sample_shadow_all_lights(screen_pos, jitterHash, vFrustum.xyz, frustum_far_plane(), depth, attenuation, volume_color);
		//float volume_attenuation = volumetric_sample_shadow(light, screen_pos, vFrustum.xyz, depth, attenuation, projector_color);
		//light_color.a = projector_color.r * volume_attenuation;
		outLightColor.rgb = volume_color;
		outOpacity = volume_attenuation;

	}

	// Legacy per-light path (not taken while use_all_lights is true).
	if (use_all_lights == false)
	{
		for(int light_idx = 0; light_idx < render_lights_params.lights_num; light_idx++)
		{
			LightProperties light = lights.light_properties[light_idx];
			// Skip lights that exclude this pixel's component tags.
			if ((light.lighting_exclusion_tags & metalness_roughness_material_tags.component_tags) != 0)
				continue;

			vec3 light_color = vec3(.0);

			// remove when we use the proxy object... -- kiero
			// NOTE: also, dont check for non-volumetric lights
			// NOTE(review): the cone test below is commented out, so this block currently has no effect.

			if ((light.type & LightType_Spot) != 0)
			{
				if ((light.type & LightType_Volumetric) != 0)
				{
					Cone cone = Cone(light.cutoff * 1.0, light.range * 2.0, light.position.xyz, light.direction.xyz);
					Ray ray = Ray(basic_params.camera_position.xyz, normalize(vFrustum.xyz));
					float max_t = depth;
					float vv = 0.0;
					vec2 intersections;
					//if (intersect_cone(cone, ray, max_t, vv, intersections) == false)
					//	continue;
				}
			}

			// NOTE(review): pointToLight is not used after this point.
			vec3 pointToLight = light.position.xyz - world.xyz;
			if ((light.type & LightType_Directional) != 0)
				pointToLight = -light.direction.xyz;

			[[branch]]
			if ((light.type & LightType_Volumetric) != 0)
			{
				// experimental lightshafts (volumetric)
				vec2 jitterHash = texelFetch(s_BlueNoise, ivec3(screen_pos.xy & ivec2(127), (globals.monotonic + light_idx) & 15), 0).rg;
				float attenuation = 0.0;
				vec4 projector_color = vec4(0.0);
				float volume_attenuation = volumetric_sample_shadow_spot(light, screen_pos, jitterHash, vFrustum.xyz, depth, attenuation, projector_color);
				//float volume_attenuation = volumetric_sample_shadow(light, screen_pos, vFrustum.xyz, depth, attenuation, projector_color);
				//light_color.a = projector_color.r * volume_attenuation;
				light_color.rgb = vec3(light.intensity * light.diffuse.rgb * volume_attenuation);
				//light_color.a = pow(volume_attenuation * light.diffuse.a, 2.0);

				// Disabled cone-intersection debug visualization.
				if (false)
				{
					Cone cone = Cone(light.cutoff * 1.0, light.range * 2.0, light.position.xyz, light.direction.xyz);
					Ray ray = Ray(basic_params.camera_position.xyz, normalize(vFrustum.xyz));
					float max_t = depth;

					float vv = 0.0;
					vec2 intersections;
					if (intersect_cone(cone, ray, max_t, vv, intersections) == true)
					{
						// Order the roots so intersections.x <= intersections.y
						// (.t is the stpq alias of the second component, i.e. .y).
						if (intersections.y < intersections.x)
						{
							float t = intersections.x;
							intersections.x = intersections.y;
							intersections.t = t;
						}

						if (intersections.x < 0.0)
							intersections.x = 0.0;

						//intersections.x = min(max_t, intersections.x);
						//intersections.y = min(max_t, intersections.y);
						
						//outColor.a = (intersections.y - intersections.x) * 0.000001;
						//outColor.rgb = vec3(intersections.y - intersections.x) * 0.0001;
						//outColor.a = 0.5;
						//outColor.rgb = vec3(0.5);
					}

					//outColor.rgb = vec3(max_t) * 0.00005;
					//outColor.a = 1.0;
					//return;
				}

				outLightColor.rgb += light_color.rgb;
			}

		}
	}

	// Final output: in-scattered light (scaled by an arbitrary 0.1) with volumetric opacity in alpha.
	outColor = vec4(outLightColor.rgb * 0.1, outOpacity);
}

