//Terrain tessellation domain shader - calculates output positions, light source view positions and performs Sobel normal calculations

//Height map used both to displace tessellated vertices (SampleDisplacementMap) and to size the Sobel texel step (SobelCalculateNormals)
Texture2D displacementTexture : register(t0);
//Sampler used for the displacement-map reads in SampleDisplacementMap
SamplerState sampler0 : register(s0);

//Per-draw transforms. main applies them as row-vector products: mul(position, matrix).
//NOTE(review): member order/types presumably mirror a CPU-side struct - do not reorder without updating it.
cbuffer MatrixBuffer : register(b0)
{
	//Object-to-world transform (also used, as its upper 3x3, to rotate normals)
	matrix worldMatrix;
	//Camera view and projection, used for SV_POSITION
	matrix viewMatrix;
	matrix projectionMatrix;
	//Sun view and projection, used to compute sunViewPos
	matrix sunViewMatrix;
	matrix sunProjectionMatrix;
	//Spot light view and projection, used to compute spotViewPos
	matrix spotViewMatrix;
	matrix spotProjectionMatrix;
	//Hill point light view and projection, one pair per direction (six directions), used to compute hillViewPositions
	matrix hillViews[6];
	matrix hillProjections[6];
};

//Per-draw parameters for the Sobel normal reconstruction in SobelCalculateNormals.
//The declaration is terminated with ';' for consistency with MatrixBuffer.
cbuffer SobelBuffer : register(b1)
{
	//Multiplier for the up (Y) component of the reconstructed normal: larger values flatten the normals (smoother look),
	//smaller values exaggerate slopes (rougher look)
	float sobelScaleFactor;
	//Unused by the shader; presumably pads the buffer to 16 bytes to mirror the CPU-side struct - confirm
	float3 padding;
};

//Patch-constant data for a quad patch: four edge factors and two inside factors, the layout the quad domain requires.
//main receives it as its first parameter but does not read it.
struct ConstantOutputType
{
	float edges[4] : SV_TessFactor;
	float inside[2] : SV_InsideTessFactor;
};

//Control point consumed by main (the element type of its OutputPatch<InputType, 4>).
//NOTE(review): field types/semantics presumably must match the hull shader's output control-point struct - confirm before changing.
struct InputType
{
	//Control-point position; only .xyz is used when interpolating
	float4 position : POSITION;
	float2 tex : TEXCOORD0;
	//Not read by this shader - normals are recomputed from the displacement map with Sobel
	float3 normal : NORMAL;
};

//Per-vertex output of the domain shader, consumed by the next pipeline stage.
//NOTE(review): the downstream stage's input struct presumably must declare the same semantics - confirm before changing.
struct OutputType
{
	//Clip-space position from the camera's view/projection
	float4 position : SV_POSITION;
	//Interpolated patch texture coordinate
	float2 tex : TEXCOORD0;
	//Normalized world-space normal reconstructed with Sobel
	float3 normal : NORMAL;
	//Position transformed by the sun's view/projection
	float4 sunViewPos : TEXCOORD1;
	//Position transformed by the spot light's view/projection
	float4 spotViewPos : TEXCOORD2;
	//World-space position
	float3 worldPosition : TEXCOORD3;
	//One position per hill-light direction; an array semantic occupies consecutive slots, here TEXCOORD4-TEXCOORD9
	float4 hillViewPositions[6] : TEXCOORD4;
};

// Samples the displacement map at texCoords and remaps the red channel to a signed height offset.
//  texCoords - texture coordinate to sample (mip level 0 is used explicitly, because implicit-derivative
//              Sample() is only available in pixel shaders)
// Returns: (r - 0.5) * 150, i.e. a height re-centred around 0; assuming a [0,1] UNORM texture this lies in [-75, 75].
float SampleDisplacementMap(float2 texCoords)
{
	//Scale applied to the re-centred height sample
	const float displacementScale = 150.0f;

	//Read the red channel explicitly instead of implicitly truncating the float4 sample to a float
	float height = displacementTexture.SampleLevel(sampler0, texCoords, 0).r;
	return (height - 0.5f) * displacementScale;
}

// Reconstructs a surface normal at texCoords by applying 3x3 Sobel kernels to the displaced heights
// of the eight neighbouring texels.
//  texCoords - texture coordinate of the centre texel
// Returns: normalize(float3(2 * filterX, filterZ, -2 * filterY)), where filterX/filterY are the Sobel
//          gradients along u/v and filterZ is the user-scaled up component.
float3 SobelCalculateNormals(float2 texCoords)
{
	//Get the width and height of the displacement texture and derive the size of one texel in UV space
	float textureWidth;
	float textureHeight;
	displacementTexture.GetDimensions(textureWidth, textureHeight);
	float2 texelSize = float2(1.0f / textureWidth, 1.0f / textureHeight);

	//Neighbouring texel coordinates, ordered row by row from (-u,-v) to (+u,+v); the centre texel is not needed
	float2 offsets[8];
	offsets[0] = texCoords + float2(-texelSize.x, -texelSize.y);
	offsets[1] = texCoords + float2(0.0f, -texelSize.y);
	offsets[2] = texCoords + float2(texelSize.x, -texelSize.y);
	offsets[3] = texCoords + float2(-texelSize.x, 0.0f);
	offsets[4] = texCoords + float2(texelSize.x, 0.0f);
	offsets[5] = texCoords + float2(-texelSize.x, texelSize.y);
	offsets[6] = texCoords + float2(0.0f, texelSize.y);
	offsets[7] = texCoords + float2(texelSize.x, texelSize.y);

	//Neighbouring texel heights - one scalar each (previously declared float2, which forced implicit truncation below)
	float samples[8];
	for (int i = 0; i < 8; i++)
	{
		samples[i] = SampleDisplacementMap(offsets[i]);
	}

	//Sobel kernels: filterX = (-u column) - (+u column), filterY = (-v row) - (+v row), each weighted 1-2-1.
	//A constant factor of 2 is applied to X and Y, and the user-set sobelScaleFactor to Z to control smoothness/roughness.
	//NOTE(review): SampleDisplacementMap returns heights in object units (scaled by 150), so filterX^2 + filterY^2
	//can easily reach 1; when it does, filterZ clamps to 0 and the normal lies flat in XZ regardless of
	//sobelScaleFactor - confirm this is intended.
	float filterX = samples[0] - samples[2] + (2.0f * samples[3]) - (2.0f * samples[4]) + samples[5] - samples[7];
	float filterY = samples[0] + (2.0f * samples[1]) + samples[2] - samples[5] - (2.0f * samples[6]) - samples[7];
	float filterZ = sobelScaleFactor * sqrt(max(0.0f, 1.0f - filterX * filterX - filterY * filterY));

	//The gradient along v maps onto the normal's Z axis (negated); Z from the kernel becomes the up (Y) component
	return normalize(float3(2.0f * filterX, filterZ, -2.0f * filterY));
}

//Plane uses quads, set domain to be quad.
// Domain shader entry point: evaluates one tessellated vertex of a quad patch.
//  input   - patch-constant tessellation factors (not read here; part of the domain-shader signature)
//  uvCoord - location of the generated vertex inside the quad patch
//  patch   - the four control points of the patch
// Returns: the displaced vertex with its clip-space position, world-space position and normal,
//          texture coordinate, and its position under each light's view/projection.
[domain("quad")]
OutputType main(ConstantOutputType input, float2 uvCoord : SV_DomainLocation, const OutputPatch<InputType, 4> patch)
{
	OutputType output;

	//Bilinearly interpolate the control-point positions; .xyz makes the float4 -> float3 narrowing explicit
	float3 v1 = lerp(patch[0].position.xyz, patch[1].position.xyz, uvCoord.y);
	float3 v2 = lerp(patch[3].position.xyz, patch[2].position.xyz, uvCoord.y);
	float3 vertexPosition = lerp(v1, v2, uvCoord.x);

	//Interpolate the texture coordinate with the same corner ordering as the position
	float2 t1 = lerp(patch[0].tex, patch[1].tex, uvCoord.y);
	float2 t2 = lerp(patch[3].tex, patch[2].tex, uvCoord.y);
	output.tex = lerp(t1, t2, uvCoord.x);

	//Displace the vertex along object-space Y by sampling the map at the calculated texture coordinate
	vertexPosition.y += SampleDisplacementMap(output.tex);

	//Use Sobel to calculate the normal after displacement, then rotate it into world space.
	//NOTE(review): the (float3x3)worldMatrix cast is only correct for rotation/uniform scale;
	//a non-uniformly scaled world matrix would need its inverse-transpose here.
	output.normal = normalize(mul(SobelCalculateNormals(output.tex), (float3x3)worldMatrix));

	//Transform to world space once and reuse it for every view/projection below
	float4 worldPosition = mul(float4(vertexPosition, 1.0f), worldMatrix);
	output.worldPosition = worldPosition.xyz;

	//Camera clip-space position
	output.position = mul(mul(worldPosition, viewMatrix), projectionMatrix);

	//Position under the sun's view/projection
	output.sunViewPos = mul(mul(worldPosition, sunViewMatrix), sunProjectionMatrix);

	//Position under the spot light's view/projection
	output.spotViewPos = mul(mul(worldPosition, spotViewMatrix), spotProjectionMatrix);

	//Positions under the hill point light's view/projection for each of its six directions
	for (int i = 0; i < 6; i++)
	{
		output.hillViewPositions[i] = mul(mul(worldPosition, hillViews[i]), hillProjections[i]);
	}

	return output;
}

