//AtmosphericReentryShader.fx
//author: Robert Lindner
//Creates a plasma effect around an object
//project: atmospheric entry geometry shader
//http://robert-lindner.com/blog/atmospheric-entry/


//GLOBAL VARIABLES
//****************
float4x4 gWorldViewProj : WORLDVIEWPROJECTION;
float4x4 gWorld: WORLD;
// Row 3 is read as the camera world position in MainPSBlinn.
float4x4 gViewInverse: VIEWINVERSE;
// Light view-projection; combined with gWorld to project vertices into shadow-map space.
float4x4 gLightViewProj;

//Enables (feature toggles read by the pixel-shader helpers)
bool gUseDiffuse = false;
bool gUseSpecular = false;
bool gUseNormal = false;
bool gInvertGreen = false;   // flips the green (Y) channel of the sampled normal map
bool gUseEnvironment = false;
bool gUseFresnel = false;
// Textures
textureCUBE gTextureEnv;     // environment cubemap sampled by CalculateSkyboxColor
Texture2D gDiffuseMap;
Texture2D gNormalMap;
Texture2D gSpecularMap;

Texture2D gShadowMap;        // depth map compared against in CalculateShadow
// Light and Shadow
// Direction the sun light travels (CalculateLighting uses -normal . lDir for N.L).
float3 gLightDirection : DIRECTION = float3(-.577f, -.577f, .577f);
// Subtracted from the light-space depth before the shadow comparison (acne bias).
float gShadowMapBias =0.005f;
// Used by texOffset to turn texel offsets into UV offsets; presumably must match the
// real shadow-map resolution set by the engine - TODO confirm.
uint2 gShadowMapSize = uint2(1920, 1080);
// 1 = shadows at full darkness, 0 = shadows have no effect.
float gShadowStrength = 1;
float gSunStrength = 0.8f;
//Atmospheric
float4 gAtmosphereColor = float4(0.5f, 0.65f, 0.9f, 1);
// Used as the re-entry light direction and to size the plasma streaks in EffectGS.
float3 gMoveDir = float3(0, -1, 0);
// NOTE(review): not read by any shader in this file.
float gMoveVel = 4000;
float gReentryStrength = 1;
// rgb = glow colour, a = intensity multiplier for the re-entry lighting term.
float4 gReentryLightColor = float4(1, 0.5f, 0.2f, 1);
//Ambient parameters
float4 gColorAmbient : COLOR = float4(0.05, 0.05, 0.05, 1.0);
float gAmbientIntensity = 1.0f;
//Diffuse parameters
float4 gColorDiffuse : COLOR = float4(1.0, 1.0, 1.0, 1.0);
//Specular parameters
float4 gColorSpecular : COLOR = float4(1.0, 1.0, 1.0, 1.0);
float gShininess = 50.0f;
//Environment (cubemap reflection) parameters
float gReflectionStrength = 1.0f;
//Fresnel rim parameters
float gFresnelPower = 1.0f;
float gFresnelMultiplier = 1.0f;
float gFresnelHardness = 1.0f;   // exponent of the downward-facing mask in CalculateFresnel
float4 gFresnelColor : COLOR = float4(1.0f,1.0f,1.0f,1.0f);
//States 
//Draws both front and back faces (set by both passes of TechBlinn).
RasterizerState gRS_NoCulling
{
	CullMode = NONE;
};
//Standard "over" alpha blending on render target 0:
//  rgb = src.rgb * src.a + dst.rgb * (1 - src.a),   a = src.a
BlendState AlphaBlending
{
	RenderTargetWriteMask[0] = 0x0F;
	BlendEnable[0] = TRUE;
	SrcBlend = SRC_ALPHA;
	DestBlend = INV_SRC_ALPHA;
	BlendOp = ADD;
	SrcBlendAlpha = ONE;
	DestBlendAlpha = ZERO;
	BlendOpAlpha = ADD;
};
//Depth testing stays enabled, but depth writes are off.
//(Not applied by any active pass; the call in p1 is commented out.)
DepthStencilState DisableDepthWriting
{
	DepthWriteMask = ZERO;
	DepthEnable = TRUE;
};
//Culls back faces (not referenced by TechBlinn in this file).
RasterizerState BackCulling
{
	CullMode = BACK;
};
//Texture Samplers
//Cubemap sampler for the environment reflection (CalculateSkyboxColor).
//Trilinear filtering, wrapping on all three axes. (A duplicated "AddressV = WRAP;"
//copy-paste line was removed; the resulting state is the same.)
SamplerState gSamplerEnvMap
{
	Filter = MIN_MAG_MIP_LINEAR;
	AddressU = WRAP;
	AddressV = WRAP;
	AddressW = WRAP;
};
//Sampler for the diffuse (and specular) maps: trilinear filtering, tiled in U and V.
SamplerState gDiffuseSampler
{
	AddressU = WRAP;
	AddressV = WRAP;
	Filter = MIN_MAG_MIP_LINEAR;
};
//Sampler for the normal map: point filtering (no interpolation between texels), tiled.
SamplerState gNormalSampler
{
	AddressU = WRAP;
	AddressV = WRAP;
	Filter = MIN_MAG_MIP_POINT;
};
//Comparison sampler for the PCF lookups in CalculateShadow: compares the reference
//depth against the stored shadow-map depth with LESS_EQUAL and filters the results.
SamplerComparisonState cmpSampler
{
   ComparisonFunc = LESS_EQUAL;
   Filter = COMPARISON_MIN_MAG_MIP_LINEAR;
   AddressU = MIRROR;
   AddressV = MIRROR;
};
//Input and output structures

// Per-vertex mesh input shared by MainVS and EffectVS.
// NOTE(review): field order/semantics presumably must match the engine-side input
// layout - confirm before reordering.
struct VS_INPUT
{
    float3 Position : POSITION;   // object-space position
    float3 Normal : NORMAL;       // object-space normal
	float3 Tangent : TANGENT;     // object-space tangent (used for normal mapping)
	float2 TexCoord : TEXCOORD0;
};
//Vertex -> pixel interpolants of the normal (opaque) pass, written by MainVS and read
//by MainPSBlinn. The tangent semantic was misspelled "TENGENT"; both stages use this
//same struct, so renaming it to TANGENT keeps the linkage intact.
struct VS_OUTPUT
{
	float4 Position : SV_POSITION;     //clip-space position
	float3 Normal : NORMAL;            //world-space normal (normalized in the pixel shader)
	float2 TexCoord : TEXCOORD0;
	float3 WorldPosition : TEXCOORD1;  //world-space position, used for the view vector
	float4 lpos : TEXCOORD2;           //position in the light's clip space (shadow mapping)
	float3 Tangent : TANGENT;          //world-space tangent
};

//Helper functions
//*****************
//Turns an integer texel offset (u, v) into a UV offset on the shadow map,
//using the resolution stored in gShadowMapSize.
float2 texOffset( int u, int v )
{
	float2 texelOffset = float2(u, v);
	return texelOffset / (float2)gShadowMapSize;
}
//Returns the unit shading normal.
//Without gUseNormal this is just the normalized interpolated normal n. With it, the
//normal-map sample is remapped from [0,1] to [-1,1] (optionally flipping green) and
//rotated from tangent space using the (tangent, binormal, normal) frame built from t and n.
float3 CalculateNormal(float3 n, float3 t, float2 texCoord)
{
	float3 N = normalize(n);
	if(!gUseNormal)
		return N;

	float3 T = normalize(t);
	float3 B = normalize(cross(T, N));
	float3x3 tangentFrame = float3x3(T, B, N);

	float3 tangentSpaceNormal = gNormalMap.Sample(gNormalSampler, texCoord).rgb * 2 - float3(1, 1, 1);
	if(gInvertGreen)
		tangentSpaceNormal.y = -tangentSpaceNormal.y;

	return normalize(mul(normalize(tangentSpaceNormal), tangentFrame));
}
//Percentage-closer filtered shadow term.
//lpos: light-space position after the perspective divide (xyz in light NDC; w unused).
//Returns 1 (fully lit) outside the light frustum. Otherwise it returns the fraction of
//the 5x5 texel neighbourhood that passes the depth comparison, lifted by
//(1 - gShadowStrength) and saturated.
float CalculateShadow(float4 lpos)
{
	//Outside the region covered by the shadow map: treat as lit
	if( lpos.x < -1.0f || lpos.x > 1.0f ||
		lpos.y < -1.0f || lpos.y > 1.0f ||
		lpos.z < 0.0f  || lpos.z > 1.0f )
	{
		return 1.0f;
	}

	//NDC [-1,1] -> texture space [0,1]; V is flipped because texture rows go down
	float2 uv = float2(lpos.x / 2 + 0.5f, lpos.y / -2 + 0.5f);
	float referenceDepth = lpos.z - gShadowMapBias;

	//5x5 PCF. The accumulator must start at zero; it was previously left
	//uninitialized, which is undefined (fxc warning X4000).
	float sum = 0.0f;
	for (int y = -2; y <= 2; ++y)
	{
		for (int x = -2; x <= 2; ++x)
		{
			sum += gShadowMap.SampleCmpLevelZero(cmpSampler, uv + texOffset(x, y), referenceDepth);
		}
	}

	float shadowFactor = sum / 25.0f;
	return saturate(shadowFactor + 1 - gShadowStrength);
}
//Sky colour seen along direction vec.
//Above the horizon it is the environment cubemap scaled by gReflectionStrength. Below
//it (vec.y < 0) it is the atmosphere colour scaled by how far down vec points. A glow
//term that peaks at the horizon (pow 15 falloff) is then added on top.
//Uses the global gAtmosphereColor. A local of the same name (0.5, 0.5, 0.9) used to
//shadow it, so the tweakable parameter had no effect on reflections.
float3 CalculateSkyboxColor(float3 vec)
{
	float3 atmosphere = gAtmosphereColor.rgb;
	float3 environment = gReflectionStrength * gTextureEnv.Sample(gSamplerEnvMap, vec).rgb;

	float atDotNeg = dot(vec, float3(0, -1, 0));
	float atDotPos = saturate(-atDotNeg);

	//Looking downward: replace the cubemap with the atmosphere colour
	if(atDotNeg > 0)
	{
		environment = atmosphere * atDotNeg;
	}

	//Horizon glow: strongest where vec is horizontal
	float atStrength = pow(1 - atDotPos, 15);
	environment += lerp(environment, atmosphere, atStrength);
	return environment;
}
//Fresnel rim term: grows towards grazing view angles (gFresnelPower, gFresnelMultiplier).
//It is masked out on surfaces whose normal points down (-Y), with gFresnelHardness
//controlling that mask, and tinted by gFresnelColor. Returns 0 when gUseFresnel is false.
float3 CalculateFresnel(float3 normal, float3 viewDirection)
{
	float facing = saturate(abs(dot(normal, viewDirection)));
	float rim = pow(1 - facing, gFresnelPower) * gFresnelMultiplier;

	float downFacing = saturate(dot(float3(0, -1, 0), normal));
	float mask = pow(1 - downFacing, gFresnelHardness);

	float strength = rim * mask;
	return strength * gFresnelColor.rgb * gUseFresnel;
}
//Blinn-Phong lighting for one directional light (no ambient, no shadowing).
//  normal        : shading normal
//  lDir          : direction the light travels (towards the surface)
//  viewDirection : direction from the camera towards the surface
//  texCoord      : UV for the optional diffuse / specular maps
//Returns diffuse + specular.
float3 CalculateLighting(float3 normal, float3 lDir, float3 viewDirection, float2 texCoord)
{
	float nDotL = max(dot(-normal, lDir), 0);
	float3 diffuse = gColorDiffuse.rgb * nDotL;
	if(gUseDiffuse)
		diffuse = diffuse * gDiffuseMap.Sample(gDiffuseSampler, texCoord).rgb;

	//Half vector between "to light" (-lDir) and "to eye" (-viewDirection)
	float3 h = -normalize(viewDirection + lDir);
	float3 specular = gColorSpecular.rgb * pow(max(dot(normal, h), 0), gShininess);
	//The specular map shares gDiffuseSampler
	if(gUseSpecular)
		specular = specular * gSpecularMap.Sample(gDiffuseSampler, texCoord).rgb;

	return diffuse + specular;
}
//Wrapped fill light tinted by atmoCol (MainPSBlinn passes lDir = (0,1,0), i.e. light
//travelling upward from below).
//The signed N.L is reshaped (pow 5 on the lit side, square root on the unlit side) and
//remapped from [-1,1] to [0,1], then multiplied by the diffuse colour/map.
float3 CalculatePlanetShine(float3 normal, float3 lDir, float3 atmoCol, float2 texCoord)
{
	float nDotL = dot(-normal, lDir);
	float shaped;
	if(nDotL < 0)
		shaped = -pow(-nDotL, 0.5f);
	else
		shaped = pow(nDotL, 5);

	float wrapped = (shaped + 1) * 0.5f;
	float3 diffuse = gColorDiffuse.rgb * wrapped;
	if(gUseDiffuse)
		diffuse = diffuse * gDiffuseMap.Sample(gDiffuseSampler, texCoord).rgb;

	return diffuse * atmoCol;
}
//Normal Pass
//******************
//Vertex shader of the normal pass.
//Outputs the clip-space position and the world-space normal/tangent/position. It also
//outputs the position in the light's clip space (lpos) for the shadow lookup.
VS_OUTPUT MainVS(VS_INPUT input)
{
	VS_OUTPUT output = (VS_OUTPUT)0;
	float4 objectPos = float4(input.Position, 1.0f);
	output.Position = mul(objectPos, gWorldViewProj);

	float4 worldPos = mul(objectPos, gWorld);
	output.WorldPosition = worldPos.xyz;   //explicit swizzle instead of implicit truncation

	output.Normal = mul(input.Normal, (float3x3)gWorld);
	output.Tangent = mul(normalize(input.Tangent), (float3x3)gWorld);
	output.TexCoord = input.TexCoord;

	//Reuse the world position: mathematically identical to
	//mul(objectPos, mul(gWorld, gLightViewProj)) without a 4x4 * 4x4 product per vertex.
	output.lpos = mul(worldPos, gLightViewProj);

	return output;
}
//Pixel shader of the normal pass. Sums five terms:
//  * the Fresnel-masked environment reflection (only when gUseEnvironment)
//  * ambient
//  * shadowed sun light
//  * planet shine from below
//  * re-entry glow lit along gMoveDir
float4 MainPSBlinn(VS_OUTPUT input) : SV_TARGET
{
	float3 N = CalculateNormal(input.Normal, input.Tangent, input.TexCoord);
	float3 V = normalize(input.WorldPosition.xyz - gViewInverse[3].xyz);
	float3 R = normalize(reflect(V, N));

	float3 environment = float3(gUseEnvironment, gUseEnvironment, gUseEnvironment)
		* CalculateSkyboxColor(R)
		* CalculateFresnel(N, V);

	float3 ambient = gColorAmbient.rgb * gAmbientIntensity;

	//Perspective divide of the light-space position before the PCF lookup
	float4 lightNdc = float4(input.lpos.xyz / input.lpos.w, input.lpos.w);
	float shadowFactor = CalculateShadow(lightNdc);

	float3 sun = CalculateLighting(N, gLightDirection, V, input.TexCoord) * gSunStrength;
	float3 planetShine = CalculatePlanetShine(N, float3(0, 1, 0), gAtmosphereColor.rgb * 0.5f, input.TexCoord);
	float3 reentry = CalculateLighting(N, gMoveDir, V, input.TexCoord) * gReentryLightColor.rgb * gReentryLightColor.a;

	float3 color = environment + ambient + shadowFactor * sun + planetShine + reentry * gReentryStrength;
	return float4(color, 1);
}


//Effect Pass
//******************
// Vertex -> geometry shader data of the effect pass (written by EffectVS).
struct GS_INPUT
{
	float3 Position : POSITION;       // object-space position (EffectGS/CreateVertex apply gWorldViewProj)
	float3 Normal : NORMAL;           // world-space normal
	float4 Color: COLOR;              // constant colour set in EffectVS
	float4 LightPosition : TEXCOORD2; // light clip-space position; NOTE(review): not read by EffectGS
	float Size: TEXCOORD0;            // half-width of the quad generated in EffectGS
};
// Geometry -> pixel shader data of the effect pass (EffectPSBlinn only reads Color).
struct GS_DATA
{
	float4 Position : SV_POSITION;
	float2 TexCoord: TEXCOORD0;
	float4 Color : COLOR;
};
//Vertex shader of the effect pass.
//Forwards the object-space position (the GS applies gWorldViewProj), the world-space
//normal, a fixed size of 1 and the re-entry colour at half alpha.
GS_INPUT EffectVS(VS_INPUT input)
{
	GS_INPUT o = (GS_INPUT)0;

	o.Position = input.Position;
	o.Normal = mul(input.Normal, (float3x3)gWorld);
	o.Size = 1;
	o.Color = float4(gReentryLightColor.rgb, 0.5f);

	//NOTE(review): LightPosition is not consumed by EffectGS or EffectPSBlinn
	o.LightPosition = mul(float4(input.Position, 1), mul(gWorld, gLightViewProj));

	return o;
}
//Projects an object-space position with gWorldViewProj and appends it, with its UV and
//colour, to the output triangle strip.
void CreateVertex(inout TriangleStream<GS_DATA> triStream, float3 pos, float2 texCoord, float4 col)
{
	GS_DATA v = (GS_DATA)0;
	v.Color = col;
	v.TexCoord = texCoord;
	v.Position = mul(float4(pos, 1.0f), gWorldViewProj);
	triStream.Append(v);
}
//Geometry shader of the effect pass. It expands each point into a quad ("plasma streak")
//in object space. The quad spans x in [-size, size] and extends along +Y from the vertex,
//so the vertex is the quad's bottom-centre. Offsets are not billboarded.
//Streak length is 10 * size * (1 - saturate(N . gMoveDir)): vertices whose normal is
//aligned with gMoveDir get short streaks, sideways-facing ones get long streaks.
//Vertices whose normal points against gMoveDir emit nothing. The original produced a
//zero-size (degenerate) quad there; skipping it gives the same image with less work.
[maxvertexcount(4)]
void EffectGS(point GS_INPUT vertex[1], inout TriangleStream<GS_DATA> triStream)
{
	float3 normal = normalize(vertex[0].Normal);
	float velDot = dot(normal, gMoveDir);
	if(velDot < 0)
		return;

	float size = vertex[0].Size;
	float3 origin = vertex[0].Position;
	//Not called "length" to avoid shadowing the length() intrinsic
	float streakLength = size * (1 - saturate(velDot)) * 10;

	float3 topLeft     = float3(-size, streakLength, 0);
	float3 topRight    = float3( size, streakLength, 0);
	float3 bottomLeft  = float3(-size, 0, 0);
	float3 bottomRight = float3( size, 0, 0);

	float4 col = vertex[0].Color;
	//Triangle strip order: BL, TL, BR, TR
	CreateVertex(triStream, origin + bottomLeft,  float2(0, 1), col);
	CreateVertex(triStream, origin + topLeft,     float2(0, 0), col);
	CreateVertex(triStream, origin + bottomRight, float2(1, 1), col);
	CreateVertex(triStream, origin + topRight,    float2(1, 0), col);
}
//Pixel shader of the effect pass: outputs the interpolated vertex colour unchanged
//(pass p1 alpha-blends it over the scene).
float4 EffectPSBlinn(GS_DATA input) : SV_TARGET
{
	float4 color = input.Color;
	return color;
}

//TECHNIQUES
//**********
//p0 draws the lit mesh; p1 draws the plasma streaks generated by EffectGS on top.
technique11 TechBlinn 
{
	pass p0 
	{
		SetRasterizerState(gRS_NoCulling);
		SetVertexShader(CompileShader(vs_4_0, MainVS()));
		//Unbind the geometry shader explicitly. A pass that does not set a stage leaves
		//whatever was bound before, so EffectGS from p1 would otherwise stay active for
		//this pass on the next frame.
		SetGeometryShader(NULL);
		SetPixelShader(CompileShader(ps_4_0, MainPSBlinn()));
		//NOTE(review): AlphaBlending set by p1 also stays bound here. This is harmless
		//because MainPSBlinn writes alpha = 1, but consider an explicit opaque blend state.
	}
	pass p1
	{
		SetRasterizerState(gRS_NoCulling);
		//SetDepthStencilState(DisableDepthWriting, 0);
		SetBlendState(AlphaBlending, float4( 0.0f, 0.0f, 0.0f, 0.0f ), 0xFFFFFFFF);

		SetVertexShader(CompileShader(vs_4_0, EffectVS()));
		SetGeometryShader(CompileShader(gs_4_0, EffectGS()));
		SetPixelShader(CompileShader(ps_4_0, EffectPSBlinn()));
	}
}