//----------------------
// Three methods in one?
//----------------------

void Server::change_role(ServerRole new_role)
{
    // Single entry point for every role transition. Each branch throws away
    // any per-role state left from an earlier stint in that role (deleting a
    // null pointer is a no-op), builds fresh state and re-arms `timer`.
    // `role` is only updated once the branch's work has run.
    if (new_role == ServerRole::FOLLOWER) {
        dout << "Switching to follower" << dendl;

        // Carry the last known leader over into the new follower state.
        const int leader_id =
            (follower_state == nullptr) ? 0 : follower_state->leader_id;
        delete follower_state;
        follower_state = new FollowerState(leader_id);

        timer.start(utils::generate_timeout(config->election_timeout));
    } else if (new_role == ServerRole::CANDIDATE) {
        dout << "switching to candidate" << dendl;

        delete candidate_state;
        candidate_state = new CandidateState();

        // Start a new term, vote for ourselves and persist that before
        // asking anyone else for a vote.
        candidate_state->votes[id] = true;
        ++current_term;
        voted_for = id;
        save_state();

        timer.start(utils::generate_timeout(config->election_timeout));
        broadcast_request_vote_req();
    } else if (new_role == ServerRole::LEADER) {
        dout << "Switching to leader" << dendl;

        delete leader_state;
        leader_state = new LeaderState(config->num_servers(), log->size());

        // Announce leadership right away, then keep heartbeating.
        broadcast_append_entries_req();
        timer.start(config->heartbeat_interval);
    }
    role = new_role;
}

//---------------------------
// Code Sample A
//---------------------------

/**
 * Handle an AppendEntries RPC and send exactly one AppendEntriesResp back
 * through raftNetwork. Send failures are logged with printf, not rethrown.
 *
 * Follows the Raft AppendEntries receiver rules:
 *  1. Reject a message whose term is lower than myTerm. The election timer is
 *     NOT reset for it: a sender with a stale term is not the current leader.
 *  2. Otherwise reset the election timer. Adopt a higher term (stepping down
 *     to follower), or step down from candidate on a same-term leader.
 *  3. Reject if leader_prev_log_index is non-zero and our log has no entry
 *     there with term leader_prev_log_term.
 *  4. Truncate our log at the first entry that conflicts with the message,
 *     then append every message entry we do not already hold.
 *  5. Raise lastCommitted to min(leader_commit_index, index of the last entry
 *     carried by this message). Push each newly committed entry onto
 *     commandQueue.
 *
 * The response's term is stamped when it is sent, so it reflects any term
 * adopted in step 2.
 *
 * @param message AppendEntries request from the leader.
 * @param id      RPC identifier the response is routed back with.
 */
void RaftServer::handleAppendEntries(
    const messages::AppendEntries& message,
    RPC_ID id
) {
    messages::RaftRPC response;
    messages::AppendEntriesResp* appendEntriesResp =
        response.mutable_append_entries_resp();
    appendEntriesResp->set_responder_network_address(
        myNetworkAddress.toString());

    // Finish and send the response. The term is set here rather than up
    // front, so a term adopted from `message` is the one reported.
    // `failurePrefix` is printed before the exception text if sending fails.
    auto respond = [&](bool success, const char* failurePrefix) {
        appendEntriesResp->set_term(myTerm);
        appendEntriesResp->set_success(success);
        try {
            raftNetwork.sendResponse(response, id);
        } catch (const NetworkException& e) {
            printf("%s%s\n", failurePrefix, e.what());
        }
    };

    // Reject lower term messages. They cannot come from the current leader,
    // so they must not postpone our election timeout either.
    if (message.term() < myTerm) {
        respond(false, "Failed to send failed append entries. ");
        return;
    }

    // Reset timer to avoid triggering new election.
    timer.reset();

    // Become a follower upon observing a higher term.
    if (message.term() > myTerm) {
        updateTermAndVote(message.term(), "");
        changeServerState(FOLLOWER);
    }

    // If we are candidate with the same term as leader, become a follower
    // but do not erase our vote for ourself.
    if (message.term() == myTerm && state == CANDIDATE) {
        changeServerState(FOLLOWER);
    }

    // Reject unless the leader is appending at the very beginning of our log
    // (prev index 0), or we hold an entry at the leader's prev log index whose
    // term matches the leader's prev log term.
    const auto prevLogIndex = message.leader_prev_log_index();
    messages::LogEntry logEntry;
    if (prevLogIndex != 0 &&
        (!logs.getEntry(logEntry, prevLogIndex) ||
         logEntry.term() != message.leader_prev_log_term())) {
        respond(false, "Failed to send failed append entries. ");
        return;
    }

    // Find the first message entry we do not already hold with the same term.
    // A missing entry means everything from there on is new. A term conflict
    // means our entry, and everything after it, must be discarded first.
    const auto& entries = message.entries();
    const unsigned int numEntries = static_cast<unsigned int>(entries.size());
    unsigned int appendFrom = numEntries;  // == numEntries: nothing to append
    for (unsigned int entriesIndex = 0; entriesIndex < numEntries;
         ++entriesIndex) {
        const unsigned int logIndex = prevLogIndex + 1 + entriesIndex;
        if (!logs.getEntry(logEntry, logIndex)) {
            appendFrom = entriesIndex;
            break;
        }
        if (logEntry.term() != entries.at(entriesIndex).term()) {
            logs.truncateTo(logIndex - 1);
            appendFrom = entriesIndex;
            break;
        }
    }

    // Append any new entries not already in my log.
    for (unsigned int entriesIndex = appendFrom; entriesIndex < numEntries;
         ++entriesIndex) {
        logs.append(entries.at(entriesIndex));
    }

    // Commit up to min(leader's commit index, index of the last entry this
    // message carried). Capping at our whole log length instead would let
    // stale entries past that point be committed merely because they have
    // not been contradicted yet. lastCommitted never moves backwards.
    const unsigned long long lastNewEntryIndex =
        static_cast<unsigned long long>(prevLogIndex) + numEntries;
    const unsigned long long newCommitIndex = std::min(
        static_cast<unsigned long long>(message.leader_commit_index()),
        lastNewEntryIndex);
    if (newCommitIndex > lastCommitted) {
        const unsigned int oldLastCommitted = lastCommitted;
        lastCommitted = newCommitIndex;
        // Push all newly committed log entries onto queue to be applied to
        // application state machine.
        messages::LogEntry entry;
        for (unsigned int i = oldLastCommitted + 1; i <= lastCommitted; ++i) {
            logs.getEntry(entry, i);
            Command newCommand(entry, -1, i);
            commandQueue.push(newCommand);
        }
    }

    // Respond to leader indicating success.
    appendEntriesResp->set_responder_last_log_index(logs.getNumLogs());
    respond(true, "");
}

//---------------------------
// Code Sample B
//---------------------------

/**
 * Handle an AppendEntriesRequest and reply to `sender_addr` with an
 * AppendEntriesResponse stamped with current_term.
 *
 * The request is rejected when the sender's term is older than ours, or when
 * we are currently the leader. Otherwise we become (or stay) a follower of
 * req.leader_id() and apply the Raft receiver rules:
 *  - fail if prev_log_index >= 0 and our entry there is missing or has a
 *    different term;
 *  - truncate at the first missing/conflicting entry and append the rest;
 *  - raise commit_index to min(leader_commit, index of last new entry),
 *    never lowering it;
 *  - run every committed but not yet applied command on state_machine.
 *
 * NOTE(review): current_term is not updated here when sender_term is higher.
 * Presumably the message dispatcher does that before calling this handler —
 * confirm.
 *
 * @param msg         Incoming message; must contain an append_entries_request.
 * @param sender_addr Address the response is sent to.
 */
void Server::handle_append_entries_req(RaftMessage* msg, NetworkAddress sender_addr)
{
    dout << "received AppendEntriesRequest from " << sender_addr << dendl;
    const AppendEntriesRequest& req = msg->append_entries_request();
    // Ownership passes to `resp` via set_allocated_append_entries_response().
    AppendEntriesResponse* response = new AppendEntriesResponse();
    if (msg->sender_term() < current_term or role == ServerRole::LEADER) {
        response->set_success(false);
    } else {
        dout << "processing AppendEntriesRequest" << dendl;
        change_role(ServerRole::FOLLOWER);
        follower_state->leader_id = req.leader_id();

        const int prev_log_index = req.prev_log_index();
        const int num_entries = req.log_entries_size();

        // Ensure the previous log entry is not a conflict. A negative
        // prev_log_index means the leader is sending from the start of the
        // log, so there is no entry to look up (and no log->at(-1) call).
        bool prev_entry_matches = true;
        if (prev_log_index >= 0) {
            LogEntry* prev_log_entry = log->at(prev_log_index);
            prev_entry_matches = prev_log_entry != nullptr and
                                 prev_log_entry->term() == req.prev_log_term();
        }

        if (not prev_entry_matches) {
            response->set_success(false);
        } else {
            response->set_success(true);
            dout << "Passed checks, handling entries[] array of size " << num_entries << dendl;

            // Skip entries we already hold with a matching term. At the first
            // missing or conflicting index, drop that entry and everything after.
            int i = 0;
            for (; i < num_entries; i++) {
                const int curr_log_index = prev_log_index + i + 1;
                LogEntry* curr_log_entry = log->at(curr_log_index);
                if (curr_log_entry == nullptr or curr_log_entry->term() != req.log_entries(i).term()) {
                    log->truncate(curr_log_index);
                    dout << "Truncating log to size " << curr_log_index << dendl;
                    break;
                }
            }

            // Copy the remaining entries (term and command) and append them.
            vector<LogEntry*> to_append {};
            to_append.reserve(num_entries - i);
            for (; i < num_entries; i++) {
                LogEntry* entry_copy = new LogEntry();
                entry_copy->set_term(req.log_entries(i).term());
                entry_copy->set_command(req.log_entries(i).command());

                to_append.push_back(entry_copy);
            }
            dout << "Appending " << to_append.size() << " entries to the log" << dendl;
            log->append(to_append);

            dout << "My commit index = " << commit_index << ", leader's commit index: " << req.leader_commit() << dendl;
            // Update commit_index to min(leader_commit, index of last new
            // entry). That minimum can be below our current commit_index
            // (e.g. a short, older batch), and commit_index must never
            // move backwards.
            if (req.leader_commit() > commit_index) {
                const int last_new_index = prev_log_index + num_entries;
                const int new_commit_index =
                    min(static_cast<int>(req.leader_commit()), last_new_index);
                if (new_commit_index > commit_index) {
                    commit_index = new_commit_index;
                }
            }

            // Apply commands that have been committed.
            while (last_applied < commit_index) {
                last_applied++;
                dout << "Applying command with index " << last_applied << dendl;
                LogEntry* entry_to_apply = log->at(last_applied);
                if (entry_to_apply == nullptr) {
                    dout << "error: committed entry not found in log" << dendl;
                } else {
                    state_machine->run_command(entry_to_apply->command());
                }
                // NOTE(review): entries returned by log->at() are deleted here
                // but not at the two log->at() calls above. Either this is a
                // double free (if the log owns its entries) or those calls
                // leak (if at() returns a fresh copy). Confirm against the
                // Log API.
                delete entry_to_apply;
            }
            response->set_next_index(log->size());
            response->set_has_next_index(true);
        }
    }

    RaftMessage* resp = new RaftMessage();
    resp->set_allocated_append_entries_response(response);
    resp->set_sender_term(current_term);
    resp->set_has_sender_term(true);
    send_raft_msg(resp, sender_addr);
}

//---------------------------------
// Duplication in exception classes
//---------------------------------

/// Exception carrying a caller-supplied message, returned verbatim by what().
struct LogException : public std::exception {
    LogException(const std::string& msg) : message_(msg) {}

    const char* what() const noexcept override { return message_.c_str(); }

private:
    std::string message_;
};

/// Exception carrying a caller-supplied message, returned verbatim by what().
struct QueueException : public std::exception {
    QueueException(const std::string& msg) : message_(msg) {}

    const char* what() const noexcept override { return message_.c_str(); }

private:
    std::string message_;
};

//---------------------------------------
// Better approach to extended exceptions
//---------------------------------------

/// Base for exceptions that carry a message; what() returns the string
/// passed to the constructor.
struct ExceptionWithMsg : public std::exception {
    public:
        ExceptionWithMsg(const std::string& msg) : _msg(msg){}

        virtual const char* what() const noexcept override {
            return _msg.c_str();
        }
    private:
        std::string _msg;
};


// Derived types must inherit the base constructor explicitly. Without the
// using-declaration they get no string constructor, and their default
// constructor is deleted (the base has none), so `LogException("...")`
// would not compile.
struct LogException : public ExceptionWithMsg {
    using ExceptionWithMsg::ExceptionWithMsg;
};
struct QueueException : public ExceptionWithMsg {
    using ExceptionWithMsg::ExceptionWithMsg;
};

//------------------------------------------------
// Even easier to use (should also eliminate the
// fixed-size buffer restriction)
//------------------------------------------------

/// Exception whose message is built printf-style from `format` and the
/// trailing arguments. The message buffer is sized to the formatted output,
/// so long messages are never truncated. If vsnprintf reports a formatting
/// error, the message is left empty.
///
/// `format` must be a trusted format string (normally a literal); passing
/// external data as the format is a format-string vulnerability.
struct ExceptionWithMsg : public std::exception {
    public:
        ExceptionWithMsg(const char *format, ...) : _msg()
        {
            va_list ap;
            va_start(ap, format);

            // First pass only measures the output. A va_list can be traversed
            // once, so the measuring pass works on a copy.
            va_list measure;
            va_copy(measure, ap);
            const int needed = vsnprintf(nullptr, 0, format, measure);
            va_end(measure);

            if (needed > 0) {
                // +1 for the terminating NUL that vsnprintf always writes. It
                // lands inside the string's own elements and is trimmed off.
                std::string buf(static_cast<std::string::size_type>(needed) + 1, '\0');
                vsnprintf(&buf[0], buf.size(), format, ap);
                buf.resize(static_cast<std::string::size_type>(needed));
                _msg.swap(buf);
            }
            va_end(ap);
        }

        virtual const char* what() const noexcept override {
            return _msg.c_str();
        }
    private:
        std::string _msg;
};