#include "pch.h"
#include "TaskManager.h"
#include "Components/RenderComponent.h"
#include "Engine/Engine.h"
#include "Game/Game.h"

//Called by a scene when it becomes the active scene.
//Stores a pointer to the scene's component map (used to create and execute all tasks), caches the thread-pool
//size, and rebuilds the key lists: component groups locked to the main thread go to m_LockedKeys, all other
//groups go to m_Keys and get a completion counter initialized to 0.
void anx::TaskManager::SetComponentMap(std::map>& componentMap)
{
	m_NrThreads = ThreadPool::GetInstance().GetNrThreads();

	m_pComponentMap = &componentMap;

	m_Keys.clear();
	m_Counters.clear();
	m_LockedKeys.clear();

	//Bind by const reference so the component vectors of the map are not copied on every iteration
	for (const auto& entry : *m_pComponentMap)
	{
		const auto& components = entry.second;

		//The first component decides for its whole group. An empty group has no component to ask
		//(front() on it would be undefined behavior) and nothing to schedule, so it is skipped.
		if (components.empty())
			continue;

		if (components.front()->IsLockedToMainThread())
		{
			m_LockedKeys.push_back(entry.first);
		}
		else
		{
			m_Keys.push_back(entry.first);
			m_Counters[entry.first].store(0);
		}
	}
}

//Hands out the next task of the current stage and advances the stage cycle once a stage is exhausted.
//Returns true and fills `task` when a task of the current stage was queued. Otherwise returns false; callers are
//presumably expected to simply call again (assumption - confirm against the worker/main loops).
//Worker threads (mainThread == false) may block on m_TaskCondition while work of the stage is still outstanding;
//the worker that finds every counter at 0 performs the transition to the next stage.
//The main thread never blocks or transitions here; when no task is queued it only wakes one worker.
bool anx::TaskManager::GetJobRequest(Task& task, bool mainThread)
{
	std::unique_locklock(m_TaskMutex);
	//If the engine is notified to quit, make the worker thread exit the loop
	if (m_Quit) return false;

	if (m_Tasks[int(m_CurrentStage)].try_pop(task))
		return true;

	//If the main thread wanted to get a job, but none is queued, notify a worker thread so it can create new tasks
	if (mainThread)
	{
		m_TaskCondition.notify_one();
		return false;
	}

	//No queued task, but work of this stage is still outstanding (running, or waiting on dependencies):
	//sleep until notified. A spurious wakeup just returns false so the caller asks again.
	if (!AllCountersZero())
	{
		m_TaskCondition.wait(lock);
		return false;
	}

	//Every worker group finished this stage: transition to the next one.
	//Stage::SIZE marks "transition in progress" while counters and flags are rebuilt.
	//NOTE(review): the cycle length is hard-coded as 5; confirm it matches the number of frame stages in Stage.
	m_NextStage = Stage((int(m_CurrentStage) + 1) % 5);
	m_CurrentStage = Stage::SIZE;

	//Reset the state of the task queue
	//NOTE(review): ResetFlags runs after ResetCounters, so it also clears the m_ComponentUpdated flags that
	//ClearEmptyCounters (called from ResetCounters) sets for idle groups; confirm this ordering is intended.
	ResetCounters();
	ResetFlags();

	//NOTE(review): SyncWithMain is called while m_TaskMutex is held.
	if (m_NextStage == Stage::SWAPPF)
	{
		//Wait for the main thread to complete drawing before swapping the buffers
		Engine::GetInstance().SyncWithMain();
	}

	if (m_NextStage == Stage::UPDATE)
	{
		//Wait for the main thread to complete handling input and cleaning resources before starting the next frame
		Engine::GetInstance().SyncWithMain();
		//Publish the new stage BEFORE releasing the lock. Assigning it after unlock() would let another worker
		//take the lock and observe the transient Stage::SIZE (indexing m_Tasks with it), and would race with
		//that worker's own writes to m_CurrentStage.
		m_CurrentStage = m_NextStage;
		//Create the frame tasks without holding the lock so other workers can start on them right away
		lock.unlock();
		CreateFrameTasks();
	}
	else
	{
		//Still holding the lock: publish the new stage
		m_CurrentStage = m_NextStage;
	}

	//Notify all waiting threads that the stage changed and new tasks may be available
	m_TaskCondition.notify_all();
	return false;
}

//Polled by the main thread to fetch tasks of component groups locked to it (groups that need SDL function calls).
//Returns true and fills `task` when such a task is queued. Returns false when quitting (without notifying) or when
//nothing is queued; in the latter case one waiting worker is woken, presumably so it can re-check the task queues.
bool anx::TaskManager::GetLockedJobRequest(Task & task)
{
	if (!m_Quit)
	{
		if (m_LockedTasks.try_pop(task))
			return true;

		m_TaskCondition.notify_one();
	}
	return false;
}

//Used by the input manager to queue a command for execution.
//The task is tagged Stage::COMMAND but pushed onto the UPDATE queue, so it is handed out to whichever thread
//requests work while the UPDATE stage is active. Null commands are ignored.
void anx::TaskManager::AddCommandJob(Command * pToAdd)
{
	if (pToAdd == nullptr)
	{
		return;
	}

	Task commandTask;
	commandTask.pCommand = pToAdd;
	commandTask.stage = Stage::COMMAND;

	m_Tasks[int(Stage::UPDATE)].push(commandTask);
	//Wake one waiting worker so the command does not sit idle in the queue
	m_TaskCondition.notify_one();
}

//Used by the scene to request loading a new scene.
//The task is tagged Stage::LOADING and pushed onto the UPDATE queue, so it is handed out to whichever thread
//requests work while the UPDATE stage is active. Null scenes are ignored.
void anx::TaskManager::AddLoadingJob(Scene * pToLoad)
{
	if (pToLoad == nullptr)
	{
		return;
	}

	Task loadingTask;
	loadingTask.pScene = pToLoad;
	loadingTask.stage = Stage::LOADING;

	m_Tasks[int(Stage::UPDATE)].push(loadingTask);
	//Wake one waiting worker so the loading request is picked up promptly
	m_TaskCondition.notify_one();
}

//Can be called from anywhere, usually the constructor of a component
//makes sure all components of type base get updated before tasks of dependantOn type get executed
void anx::TaskManager::SetDependecny(size_t base, size_t dependantOn)
{
	m_Dependencies[base].push_back(dependantOn);
}

//Returns the stage whose tasks are currently handed out by GetJobRequest.
//While GetJobRequest is transitioning between stages this reads Stage::SIZE.
anx::Stage anx::TaskManager::GetCurrentStage() const
{
	const Stage activeStage = m_CurrentStage;
	return activeStage;
}


//Has to be called after the component group `componentHash` has completed for the current stage.
//Every worker group (m_Keys) that is not yet finished and lists `componentHash` among its dependencies gets a
//chance to create its tasks for m_CurrentStage; CreateTasks itself verifies that all of that group's other
//dependencies have finished too. Main-thread (locked) groups are not checked here.
void anx::TaskManager::CheckDependencies(size_t componentHash)
{
	for (const auto& key : m_Keys)
	{
		//If this component type has already been marked updated, ignore it
		if (m_ComponentUpdated[key])
			continue;

		//Single, non-inserting lookup: groups without registered dependencies cannot be waiting on componentHash
		const auto dependencyIt = m_Dependencies.find(key);
		if (dependencyIt == m_Dependencies.end())
			continue;

		const auto& dependencies = dependencyIt->second;
		//If this key was depending on the completed group, try to create its task set
		if (std::find(dependencies.begin(), dependencies.end(), componentHash) != dependencies.end())
		{
			CreateTasks(key, m_CurrentStage);
		}
	}
}

//Clears the counters for all component types that don't have to be updated this frame
//Also marking these components complete and checks for dependant components
void anx::TaskManager::ClearEmptyCounters()
{
	for (size_t i = 0; i < m_Keys.size(); ++i)
	{
		size_t componentHash = m_Keys[i];
		if (!m_pComponentMap->at(componentHash).at(0)->ShouldExecuteStage(m_NextStage))
		{
			m_Counters[componentHash] = 0;
			m_ComponentUpdated[componentHash] = true;
			CheckDependencies(componentHash);
		}
	}
	for (size_t i = 0; i < m_LockedKeys.size(); ++i)
	{
		size_t componentHash = m_LockedKeys[i];
		if (!m_pComponentMap->at(componentHash).at(0)->ShouldExecuteStage(m_NextStage))
		{
			m_Counters[componentHash] = 0;
			m_ComponentUpdated[componentHash] = true;
			CheckDependencies(componentHash);
		}
	}
}

//Sets all counters to the number of the corresponding components in the current scene
void anx::TaskManager::ResetCounters()
{
	if(m_NextStage == Stage::PHYSICSINIT)
	{
		size_t nrColliders = m_pComponentMap->at(m_ColliderHash).size();
		m_Counters[m_ColliderHash] = nrColliders;
	}
	else if (m_NextStage == Stage::PHYSICSMAIN)
	{
		CreateBaseTasks(Stage::PHYSICSMAIN);
		auto* physScene = Engine::GetInstance().GetCurrentGame()->GetCurrentScene()->GetPhysicsScene();
		int size = int(physScene->GetTree()->GetNotEmptyNodes().size());
		m_Counters[m_ColliderHash] = size;
	}
	else
	{
		for (size_t i = 0; i < m_Keys.size(); ++i)
		{
			size_t index = m_Keys.at(i);
			m_Counters.at(index) = int(m_pComponentMap->at(index).size());
		}
		ClearEmptyCounters();
	}
}

//Creates the initial tasks of the given stage for the active scene.
//UPDATE / LATEUPDATE / SWAPPF: every component group gets a chance to create its tasks (worker groups through
//CreateTasks, main-thread groups through CreateLockedTasks); groups with unfinished dependencies create none yet.
//PHYSICSINIT: starts the physics update, then splits the collider components into ranges of
//colliderCount / m_NrThreads items (at least 1).
//PHYSICSMAIN: splits the non-empty nodes of the physics tree into ranges of nodeCount / m_NrThreads (at least 1).
//Any other stage creates nothing.
void anx::TaskManager::CreateBaseTasks(Stage stage)
{
	//Queues collider tasks that cover [0, total) in consecutive chunks of perTask items
	auto queueColliderChunks = [this, stage](size_t total, size_t perTask)
	{
		for (size_t first = 0; first < total; first += perTask)
		{
			Task newTask;
			newTask.key = m_ColliderHash;
			newTask.range = std::make_pair(first, std::min(first + perTask, total));
			newTask.stage = stage;
			m_Tasks[int(stage)].push(newTask);
		}
	};

	switch (stage)
	{
	case Stage::UPDATE:
	case Stage::LATEUPDATE:
	case Stage::SWAPPF:
		for (const auto& key : m_Keys)
			CreateTasks(key, stage);

		for (const auto& lockedKey : m_LockedKeys)
			CreateLockedTasks(lockedKey, stage);
		break;

	case Stage::PHYSICSINIT:
	{
		auto* physicsScene = Engine::GetInstance().GetCurrentGame()->GetCurrentScene()->GetPhysicsScene();
		physicsScene->StartUpdate();

		const size_t colliderCount = m_pComponentMap->at(m_ColliderHash).size();
		const size_t collidersPerTask = std::max(1, int(colliderCount / m_NrThreads));
		queueColliderChunks(colliderCount, collidersPerTask);
		break;
	}

	case Stage::PHYSICSMAIN:
	{
		auto* physicsScene = Engine::GetInstance().GetCurrentGame()->GetCurrentScene()->GetPhysicsScene();

		const int nodeCount = int(physicsScene->GetTree()->GetNotEmptyNodes().size());
		const int nodesPerTask = std::max(1, int(nodeCount / m_NrThreads));
		queueColliderChunks(size_t(nodeCount), size_t(nodesPerTask));
		break;
	}

	default:
		break;
	}
}

//Reduces the completion counter of the given component group by `toReduce` (presumably called by a worker after
//it finished a task's range - confirm against the callers).
//Exactly one caller - the one whose subtraction brings the counter to 0 - flags the group as completed and runs
//CheckDependencies so dependant component types can create their tasks.
void anx::TaskManager::ReduceCounter(size_t componentHash, int toReduce)
{
	auto& counter = m_Counters.at(componentHash);
	//Test the value produced by the atomic subtraction itself. Re-reading the counter afterwards lets two
	//concurrent reducers both observe 0 and create the dependant groups' tasks twice.
	if ((counter -= toReduce) == 0)
	{
		m_ComponentUpdated[componentHash] = true;
		CheckDependencies(componentHash);
	}
}

//Returns the component group registered under `componentHash` in the active scene's component map.
//Throws std::out_of_range when the type is not present; SetComponentMap must have been called first.
std::vector& anx::TaskManager::GetComponents(size_t componentHash)
{
	auto& componentMap = *m_pComponentMap;
	return componentMap.at(componentHash);
}

void anx::TaskManager::Quit()
{
	m_Quit = true;
	m_TaskCondition.notify_all();
}

//Starts in Stage::SWAPPF, so the first stage transition in GetJobRequest moves to the stage following it
//(presumably UPDATE, which creates the first frame's tasks - confirm against the Stage enum).
//No component map is set yet (m_pComponentMap is nullptr): SetComponentMap must be called before tasks are requested.
//m_RenderHash / m_ColliderHash cache component type IDs; m_ColliderHash identifies the group split across the
//physics stages.
//NOTE(review): m_NrThreads and m_NextStage are not initialized here (m_NrThreads is assigned in SetComponentMap);
//confirm both have in-class initializers or are assigned before first use.
anx::TaskManager::TaskManager()
	:m_CurrentStage(Stage::SWAPPF), m_pComponentMap(nullptr), m_RenderHash(ComponentID::GetID())
	,m_ColliderHash(ComponentID::GetID())
{
}

//Creates the worker-thread tasks for one component group in the given stage.
//Nothing is created when the group is empty, when its components don't take part in the stage, or when any group
//it depends on (see SetDependecny) still has a non-zero counter; in that last case CheckDependencies calls this
//again once the dependency finishes.
//The group is split into ranges of size / m_NrThreads components, with a minimum of 10 components per task.
void anx::TaskManager::CreateTasks(size_t componentHash, Stage stage)
{
	const auto& components = m_pComponentMap->at(componentHash);
	const int size = int(components.size());

	//An empty group has nothing to schedule (and .at(0) below would throw)
	if (size == 0)
		return;

	if (!components.at(0)->ShouldExecuteStage(stage))
		return;

	//Non-inserting lookup: a group without registered dependencies can always proceed
	const auto dependencyIt = m_Dependencies.find(componentHash);
	if (dependencyIt != m_Dependencies.end())
	{
		for (const auto& dependency : dependencyIt->second)
		{
			//If the counter of a group this one depends on is not yet 0, don't create the tasks
			if (m_Counters.at(dependency) != 0)
				return;
		}
	}

	const int nrObjectsPerTask = std::max(int(size / m_NrThreads), 10);
	for (int first = 0; first < size; first += nrObjectsPerTask)
	{
		Task newTask;
		newTask.key = componentHash;
		newTask.range = std::make_pair(first, std::min(first + nrObjectsPerTask, size));
		newTask.stage = stage;
		m_Tasks[int(stage)].push(newTask);
	}

	//Several tasks may have been queued: wake every idle worker instead of just one
	m_TaskCondition.notify_all();
}

//Creates the tasks of one main-thread-locked component group in the given stage; they are queued on
//m_LockedTasks and fetched by the main thread through GetLockedJobRequest.
//Nothing is created when the group is empty, when its components don't take part in the stage, or when any group
//it depends on (see SetDependecny) still has a non-zero counter.
//The group is split into chunks of at most m_MaxTaskSize components; a small group becomes a single task.
void anx::TaskManager::CreateLockedTasks(size_t componentHash, Stage stage)
{
	const auto& components = m_pComponentMap->at(componentHash);
	const int size = int(components.size());

	//An empty group has nothing to schedule (and .at(0) below would throw)
	if (size == 0)
		return;

	if (!components.at(0)->ShouldExecuteStage(stage))
		return;

	//Non-inserting lookup: a group without registered dependencies can always proceed
	const auto dependencyIt = m_Dependencies.find(componentHash);
	if (dependencyIt != m_Dependencies.end())
	{
		for (const auto& dependency : dependencyIt->second)
		{
			//If the counter of a group this one depends on is not yet 0, don't create the tasks
			if (m_Counters.at(dependency) != 0)
				return;
		}
	}

	//A non-positive m_MaxTaskSize would make the chunk loop endless; fall back to one task for the whole group
	const int chunkSize = m_MaxTaskSize > 0 ? m_MaxTaskSize : size;
	for (int first = 0; first < size; first += chunkSize)
	{
		Task newTask;
		newTask.key = componentHash;
		newTask.range = std::make_pair(first, std::min(first + chunkSize, size));
		newTask.stage = stage;
		m_LockedTasks.push(newTask);
	}
	m_TaskCondition.notify_one();
}

//Entry point for creating a new frame's tasks: calls CreateBaseTasks for every stage below Stage::SIZE.
//PHYSICSMAIN is skipped because its tasks are created by ResetCounters when that stage is about to start.
void anx::TaskManager::CreateFrameTasks()
{
	const int stageCount = int(Stage::SIZE);
	for (int stageIndex = 0; stageIndex < stageCount; ++stageIndex)
	{
		if (stageIndex == int(Stage::PHYSICSMAIN))
			continue;

		CreateBaseTasks(Stage(stageIndex));
	}
}

//Returns true when no worker (non-locked) component group has work left in the current stage, i.e. every
//counter of a key in m_Keys is 0 or below. Main-thread (locked) groups are not considered.
//In this file it is called from GetJobRequest while m_TaskMutex is held.
bool anx::TaskManager::AllCountersZero()
{
	bool anyPending = false;
	for (const auto& key : m_Keys)
	{
		if (m_Counters[key] > 0)
		{
			anyPending = true;
			break;
		}
	}
	return !anyPending;
}

//Clears the "completed this stage" flag of every worker (non-locked) component group.
void anx::TaskManager::ResetFlags()
{
	for (const auto& key : m_Keys)
		m_ComponentUpdated[key] = false;
}