/* Copyright
 * ========================================================================================
 * Project:		Accurate DRAM Model
 * Author:		Nan Li, KTH
 * ID:			adm_scheduler.h, v1.0, 2010/12/04
 *
 * Description:	Defines the DRAM scheduler class
 *
 * ========================================================================================
 * Version History
 * ========================================================================================
 * Version 1.0: Supports three basic scheduling policies: FIFO, priority, and round-robin.
 * Version 0.1: Draft version
 * ========================================================================================
 */

#ifndef ADM_SCHEDULER_H_
#define ADM_SCHEDULER_H_

#include "systemc.h"
#include "ocpip.h"
#include <map>
#include <deque>

template <unsigned int BUSWIDTH = 32>
class adm_scheduler : public sc_module {
public:
	typedef typename ocpip::ocp_data_class_unsigned<BUSWIDTH, 32>::DataType Td;
	enum scheduling_policy {
		FIFO,
		PRIORITY,
		ROUND_ROBIN
	};
	typedef std::map<unsigned int, std::deque<tlm::tlm_generic_payload *> >::iterator QueueIter;

	sc_in_clk clk;
	ocpip::ocp_master_socket_tl1<sizeof(Td) * 8> ocpInitPort;
	ocpip::ocp_slave_socket_tl1<sizeof(Td) * 8> ocpTargetPort;

	/* constructor */
	SC_HAS_PROCESS(adm_scheduler);
	adm_scheduler(const sc_module_name &nm, scheduling_policy pol = PRIORITY, uint64 rrSlotSize = 0)
			: sc_module(nm),
			  clk("clk"),
			  ocpInitPort("init_port", ocpip::ocp_master_socket_tl1<sizeof(Td)*8>::mm_txn_with_data()),
			  ocpTargetPort("target_port"),
			  m_reqSrc(0),
			  m_reqDst(0),
			  m_rspSrc(0),
			  m_rspInt(0),
			  m_rspDst(0),
			  m_policy(pol),
			  m_cycle(0),
			  m_roundRobinSlotSize(rrSlotSize),
			  m_roundRobinNextSlotHop(0)
	{
		/* register transport functions */
		ocpInitPort.register_nb_transport_bw(this, &adm_scheduler::nb_transport_bw);
		ocpInitPort.activate_synchronization_protection();
		ocpTargetPort.register_nb_transport_fw(this, &adm_scheduler::nb_transport_fw);
		ocpTargetPort.activate_synchronization_protection();

		/* process for forward path */
		SC_METHOD(proc_fw);
		sensitive << clk.pos();
		dont_initialize();

		/* process for backward path */
		SC_METHOD(proc_bw);
		sensitive << clk.pos();
		dont_initialize();

		/* thread for cycle counting */
		SC_THREAD(countCycle);
		sensitive << clk.pos();
		dont_initialize();

		/* initialization of round robin pointer */
		if (m_policy == ROUND_ROBIN)
			m_roundRobinIter = m_reqQueues.end();
	}

private:
	tlm::tlm_sync_enum nb_transport_fw(tlm::tlm_generic_payload& txn, tlm::tlm_phase& ph, sc_core::sc_time& tim) {
		if (ph == tlm::BEGIN_REQ) {
			sc_assert(!m_reqSrc);
			m_reqSrc = &txn;
			ph = tlm::END_REQ;
			return tlm::TLM_UPDATED;
		} else if (ph == ocpip::BEGIN_DATA) {
			m_dataCountMap[&txn]++;
			ph = ocpip::END_DATA;
			return tlm::TLM_UPDATED;
		} else if (ph == tlm::END_RESP) {
			m_rspDst = 0;
		} else {
			sc_assert(false);
		}

		return tlm::TLM_ACCEPTED;
	}
	tlm::tlm_sync_enum nb_transport_bw(tlm::tlm_generic_payload& txn, tlm::tlm_phase& ph, sc_core::sc_time& tim) {
		if (ph == tlm::BEGIN_RESP) {
			sc_assert(!m_rspSrc);
			if (!m_rspInt) {
				m_rspInt = &txn;
				ph = tlm::END_RESP;
				return tlm::TLM_UPDATED;
			}
			m_rspSrc = &txn;
		} else if (ph == tlm::END_REQ) {
			m_reqDst = 0;
		} else {
			sc_assert(false);
		}

		return tlm::TLM_ACCEPTED;
	}

	void proc_fw() {
		if (m_reqSrc) {
			unsigned int threadId;

			ocpip::thread_id *threadIdExt;
			if ((m_policy != FIFO) && ocpTargetPort.get_extension(threadIdExt, *m_reqSrc))
				threadId = threadIdExt->value;
			else
				threadId = 0;	/* 0 if policy is FIFO, or the request doesn't have thread_id extension */

			m_reqQueues[threadId].push_back(m_reqSrc);
			if ((m_policy == ROUND_ROBIN) && (m_roundRobinIter == m_reqQueues.end())) {
				m_roundRobinIter = m_reqQueues.begin();
				m_roundRobinNextSlotHop = m_cycle + m_roundRobinSlotSize;
			}

			m_reqSrc = 0;
		}

		if (!m_reqDst) {
			static unsigned int writeDataRemain = 0;
			static tlm::tlm_generic_payload *writeDataTxn = 0;

			if (writeDataRemain) {
				if (m_dataCountMap[writeDataTxn]) {
					m_dataCountMap[writeDataTxn]--;
					writeDataRemain--;

					if (!writeDataRemain)
						m_dataCountMap.erase(writeDataTxn);

					/* make forward data transport for write transactions */
					tlm::tlm_phase phase = ocpip::BEGIN_DATA;
					sc_time time = SC_ZERO_TIME;
					ocpInitPort->nb_transport_fw(*writeDataTxn, phase, time);
				}
			} else {
				unsigned int picked = 0;
				if (m_policy == FIFO) {
					picked = 0;
				} else if (m_policy == PRIORITY) {
					for (QueueIter it = m_reqQueues.begin(); it != m_reqQueues.end(); it++) {
						picked = it->first;
						if (!it->second.empty()) {
							break;
						}
					}
				} else if (m_policy == ROUND_ROBIN) {
					if (m_roundRobinIter != m_reqQueues.end()) {
						if ((m_cycle < m_roundRobinNextSlotHop) && !m_roundRobinIter->second.empty()) {
							/* we haven't reached the cycle for next slot hop */
							picked = m_roundRobinIter->first;
						} else {
							/* hop to next slot */
							m_roundRobinIter++;
							if (m_roundRobinIter == m_reqQueues.end())
								m_roundRobinIter = m_reqQueues.begin();

							/* if the queue is empty, hop to next slot */
							QueueIter it = m_roundRobinIter;
							while (it->second.empty()) {
								it++;
								if (it == m_reqQueues.end())
									it = m_reqQueues.begin();

								if (it == m_roundRobinIter)
									break;
							}

							picked = it->first;;
							m_roundRobinIter = it;
							m_roundRobinNextSlotHop = m_cycle + m_roundRobinSlotSize;
						}
					}
				}

				if (m_reqQueues.count(picked) && !m_reqQueues[picked].empty()) {
					m_reqDst = m_reqQueues[picked].front();
					m_reqQueues[picked].pop_front();

					if (m_reqDst->get_command() == tlm::TLM_WRITE_COMMAND) {
						writeDataTxn = m_reqDst;
						writeDataRemain = calculateBurstLength(*writeDataTxn);
					}

					/* make a forward transport to the target (DRAM controller) */
					tlm::tlm_phase phase = tlm::BEGIN_REQ;
					sc_time time = SC_ZERO_TIME;
					if ((ocpInitPort->nb_transport_fw(*m_reqDst,phase, time) == tlm::TLM_UPDATED) && (phase == tlm::END_RESP))
						m_reqDst = 0;
				}
			}
		}
	}

	void proc_bw() {
		if (m_rspDst || (!m_rspSrc && !m_rspInt)) return;

		if (m_rspSrc) {
			m_rspInt = m_rspSrc;
			/* make a forward transport with END_RESP to the DRAM controller */
			tlm::tlm_phase phase = tlm::END_RESP;
			sc_time time = SC_ZERO_TIME;
			ocpInitPort->nb_transport_fw(*m_rspSrc, phase, time);
		}

		if (m_rspInt) {
			m_rspDst = m_rspInt;
			m_rspInt = m_rspSrc;
			m_rspSrc = 0;
		} else {
			m_rspDst = m_rspSrc;
			m_rspSrc = 0;
		}

		/* make a backward transport with BEGIN_RESP to the initiator */
		tlm::tlm_phase phase = tlm::BEGIN_RESP;
		sc_time time = SC_ZERO_TIME;
		if ((ocpTargetPort->nb_transport_bw(*m_rspDst, phase, time) == tlm::TLM_UPDATED) && (phase == tlm::END_RESP))
			m_rspDst = 0;
	}

	void countCycle() {
		while (true) {
			wait(SC_ZERO_TIME);
			m_cycle++;

			wait();
		}
	}

	unsigned int calculateBurstLength(tlm::tlm_generic_payload &txn) {
		ocpip::burst_length *bLen;
		if (ocpip::extension_api::get_extension<ocpip::burst_length>(bLen, txn))
			return bLen->value;
		else {
			SC_REPORT_WARNING(SC_ID_WITHOUT_MESSAGE_, "burst_length extension is not used. Calculating burst length from data length.");
			return txn.get_data_length() / sizeof(Td);
		}
	}

private:
	/* state variables */
	/* requests and responses */
	tlm::tlm_generic_payload *m_reqSrc, *m_reqDst, *m_rspSrc, *m_rspInt, *m_rspDst;

	std::map<tlm::tlm_generic_payload *, int> m_dataCountMap;
	std::map<unsigned int, std::deque<tlm::tlm_generic_payload *> > m_reqQueues;
	scheduling_policy m_policy;

	/* current simulation cycle */
	uint64 m_cycle;

	/* round-robin related state variables */
	QueueIter m_roundRobinIter;
	uint64 m_roundRobinSlotSize, m_roundRobinNextSlotHop;
};

#endif /* ADM_SCHEDULER_H_ */
