/*
 * Copyright (c) 2012-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 */

#ifndef AGG_FABRIC_GRAPH_H_
#define AGG_FABRIC_GRAPH_H_

#include <sys/types.h>

#include <cstdint>
#include <exception>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "agg_ib_types.h"
#include "agg_types.h"
#include "am_common.h"
#include "amkey_manager.h"
#include "an_config_manager.h"
#include "fabric_graph_data.h"
#include "fabric_graph_update.h"
#include "fabric_update.h"
#include "option_manager.h"
#include "port_data.h"

class Vnode;
class Vport;

using ListPortPtr = std::list<class Port*>;

using VectorOfPorts = std::vector<Port*>;
using VectorsOfListOfNodes = std::vector<ListOfNodes>;
using SetPortsPtr = std::set<class Port*>;
using SetOfGuids = std::set<uint64_t>;

using MapGuidToVnodePtr = std::map<uint64_t, std::unique_ptr<Vnode>>;
using MapStrToListOfNodes = std::map<string, ListOfNodes>;
using MapGuidToHCCoordinates = std::map<uint64_t, uint16_t>;
using MapGuidToVportPtr = std::map<uint64_t, std::unique_ptr<Vport>>;
using VectorOfNodes = std::vector<class Node*>;
using VectorOfVports = std::vector<class Vport*>;

// Used for hash set of invalid ports
using PairNodeGuidPortNum = std::pair<u_int64_t, phys_port_t>;
using HashMapNodeGuidPortNumToStr = std::unordered_map<PairNodeGuidPortNum, char const* const, pair_hash>;

// Exception type for fabric database errors.
// Overrides what() so that handlers catching std::exception get a
// descriptive name instead of the implementation-defined default text.
class FabricDbException : public std::exception
{
   public:
    const char* what() const noexcept override { return "FabricDbException"; }
};

// FabricGraph - in-memory model of the fabric (nodes, ports, vnodes/vports,
// AggNode ports and paths). It is populated through the smdb / virt CSV parser
// callbacks declared below and drives AggNode discovery/configuration through
// the fabric-provider callbacks.
class FabricGraph
{
   protected:
    FabricProvider m_fabric_provider_;
    AMKeyManager m_amkey_manager_;
    CommandManager* m_command_manager_ptr_;

    // Data structures used during smdb/virt file parsing
    MapPortKeyToPairNodeKeyAndPortIndex m_ca_port_key_to_node_port_key_and_index_;   // map of all fabric ca ports
    MapGuidToNodePtr m_node_by_guid_;                                                // Provides the node by guid
    MapGuidToVnodePtr m_vnode_by_guid_;                                              // Provides the vnode by guid
    VectorsOfListOfNodes m_node_by_rank_;                                            // Provides the node by node rank (rank 0 is root)
    MapGuidToVportPtr m_vport_by_guid_;                                              // set of all fabric vports
    VectorOfPorts m_port_by_lid_;                                                    // Pointer to the Port by its lid
    uint8_t m_max_rank_;
    lid_t m_max_lid_;   // Track max lid used.
    SetNodesPtr m_root_nodes_;
    SetNodesPtr m_sw_nodes_;   // Switch nodes.
    std::set<uint16_t> m_coordinates_set;
    uint16_t m_max_dfp_group_;
    uint64_t m_subnet_prefix_;
    SetOfGuids m_ignore_host_guids_;

    uint64_t m_epoch_;   // Serial number of smdb updates
    uint64_t m_vport_epoch_;   // presumably the vport-file counterpart of m_epoch_ - confirm
    bool m_are_vports_inconsistent_with_physical_ports_;

    // In case AM recovered(AMKey), need to rediscover the fabric regardless of smdb file change.
    bool m_rediscover_required_;

    MapPortKeyToAnPortPtr m_an_port_by_key_;   // AggNode ports by port key
    ListAggPathPtr m_paths_;                   // list of all paths in the fabric

    FabricTopologyData m_topology_data_;
    port_key_t m_sm_port_guid_;

    MapAnToAnInfo m_map_an_to_an_info_;   // Routing information of single
                                          // hop path

    // temp update DB
    ListPortDataUpdate m_ports_data_update_;
    ListPathUpdate m_paths_update_;
    port_key_t m_sm_port_guid_update_;
    // SetPortDataUpdate       m_delayed_port_data_updates_;
    MapPortKeyToAnPortPtr m_mad_send_retry_;
    bool m_startup_update_fabric_state_;
    HashMapNodeGuidPortNumToStr m_invalid_ports_hash_;

    bool m_job_handling_started_;
    // control_path_version (IB: active AM class version) to be set on all ANs
    uint8_t m_control_path_version_;
    uint16_t m_min_tree_table_size_;
    u_int16_t m_data_path_version_;

    sharp_job_id_t m_max_jobs_number_;

    // device configuration manager
    AnConfigManager m_an_config_manager_;

   public:
    // Constructor
    FabricGraph(CommandManager* command_manager_ptr);

    // FabricGraph is created and owned by FilesParserManager; it must never be copied.
    FabricGraph(const FabricGraph&) = delete;
    FabricGraph& operator=(const FabricGraph&) = delete;

    // Processes states of AggNodes
    void ProcessAggNodesStates(const MapPortKeyToAnPortPtr& p_ports);

    // Get Topology type from one location in order to support
    // setting type different from type configured, if required.
    const FabricTopologyInfo& GetFabricTopologyInfo() const { return m_topology_data_.GetTopologyInfo(); }

    TopologyType GetTopologyType() const { return m_topology_data_.GetTopologyInfo().m_topology_type; }

    CommandManager& GetCommandManager() const { return *m_command_manager_ptr_; }

    void ResizeTopologyData();

    int Init();

    inline bool IsRediscoverRequired() const { return m_rediscover_required_; }
    inline void clearRediscoverRequired() { m_rediscover_required_ = false; }

    void StartCommandHandling();
    bool HandleAggNodesInitState(bool seamless_restart);
    void CompareAggNodePortConfigWithConf();

    inline void DumpAMKeysToFile() { m_amkey_manager_.DumpAMKeysToFile(); }

    // Validates that AM is running in SM port
    void ValidateLocalPort();
    int UpdateFabricStart();
    void LogInvalidPorts();
    int UpdateFabricEnd();
    int UpdateFabricFailed();
    void RevertFabricEpoch();

    int UpdateVportStart();
    int UpdateVportEnd();
    int UpdateVportFailed();
    void RevertVportEpoch();

    // Retry sending mad that received temporary error on the current epoch
    void MadSendRetry();

    // return MAX_NUM_HOPS if no path found
    uint8_t GetNumHops(uint64_t from_sw_guid, uint64_t to_sw_guid);

    // Add a link into the fabric - this will create nodes / ports and link between them
    // by calling the forward methods MakeNode + MakeLinkBetweenPorts
    int AddLink(const string& type1,
                phys_port_t num_ports_1,
                uint64_t node_guid_1,
                uint64_t port_guid_1,
                string& desc1,
                lid_t lid1,
                uint8_t lmc1,
                phys_port_t port_num_1,
                const string& type2,
                phys_port_t num_ports_2,
                uint64_t node_guid_2,
                uint64_t port_guid_2,
                string& desc2,
                lid_t lid2,
                uint8_t lmc2,
                phys_port_t port_num_2);

    uint32_t GetNodesNumber() const { return static_cast<uint32_t>(m_node_by_guid_.size()); }

    uint32_t GetCaPortsNumber() const { return static_cast<uint32_t>(m_ca_port_key_to_node_port_key_and_index_.size()); }

    // Sum of GetNumberOfAllocatedPorts() over every node in the graph.
    std::size_t GetNumberOfPhysicalPorts() const
    {
        std::size_t number_of_ports = 0;
        for (const auto& current_node_pair : m_node_by_guid_) {
            number_of_ports += current_node_pair.second->GetNumberOfAllocatedPorts();
        }
        return number_of_ports;
    }

    int AssignNodesRank(SetOfGuids& root_guids);
    int AssignNodesHyperCubeCoordinates(MapGuidToHCCoordinates& coordinates_map);

    bool SetRetryOnMadFailure(int rec_status, Port* p_port, NodeState state);

    void SetAggNodeActiveState(Port* p_port);
    void HandleFabricUpdates();

    ////////////////////////////////////////////
    /// Fabric CSV Parser Call Backs Function
    ////////////////////////////////////////////
    int CreateNode(const NodeRecord& node_record);
    void UpdateEpochForNodes();
    int CreatePort(const PortRecord& port_record);
    void MarkPortForDeletion(const PortRecord& port_record, Node* p_node, char const* const reason);
    void DisablePort(const u_int64_t node_guid, const phys_port_t port_num);
    int CreateLink(const LinkRecord& link_record);
    int CreateSwitchTopoTree(const SwitchRecord& switch_record);
    int AssignNodeHyperCubeCoordinate(const SwitchRecord& switch_record);
    int AssignNodeDfpGroupInfo(const SwitchRecord& switch_record);
    int UpdateAnToAnRouting(const AnToAnRecord& an_to_an_record);
    int ParseSmRecord(const SmRecord& sm_record);
    int ParseSmPortsRecord(const SmPortsRecord& sm_ports_record);
    int ParseSmsRecord(const SmsRecord& sms_record);
    TopologyType TopologyStrToType(const std::string& topology_str);

    //////////////////////////////////////////////////
    /// Fabric virtual CSV Parser Call Backs Function
    //////////////////////////////////////////////////
    int CreateVnode(const VnodeRecord& vnode_record);
    int CreateVport(const VportRecord& vport_record);

    ////////////////////////////////////////////
    /// Fabric Provider Call Backs Function
    ////////////////////////////////////////////

    void SetAMKeyCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    void RecoverAMKeyCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    void DiscoverAggNodeCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    void RediscoverAggNodeCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    void CleanAggNodeCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    void ConfigureAggNodeCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

   private:
    int CalculateNodesRank();

    void CalculateMinHopsTables();

    Node* GetNodeByGuid(const uint64_t node_guid);
    Vnode* GetVnodeByGuid(const uint64_t vnode_guid);
    Port* GetCaPortByPortKey(const hca_port_key_t port_key);
    Vport* GetVportByGuid(const uint64_t vport_guid);

    // create a new node in fabric (don't check if exists already)
    Node* MakeNode(const NodeType type, const phys_port_t num_ports, const uint64_t node_guid, const string& node_description);

    Vnode* MakeVnode(const uint64_t vnode_guid, const string& vnode_description);

    Vport* MakeVport(const uint64_t vport_guid,
                     const lid_t vlid,
                     const phys_port_t vport_num,
                     Vnode* p_vnode,
                     Port* p_port,
                     const uint8_t is_lid_required);

    // set the node's port given data (create one if it does not exist).
    Port* SetNodePort(Node* p_node,
                      const uint64_t port_guid,
                      const lid_t lid,
                      const uint8_t lmc,
                      const phys_port_t port_number,
                      const PortTimestamp& timestamp,
                      const SpecialPortType special_port_type,
                      const uint16_t port_rate,
                      const string& port_label,
                      const uint16_t aport,
                      const uint8_t plane_number,
                      const uint8_t number_of_planes);

    // Add a link between the given ports.
    // not creating sys ports for now.
    int MakeLinkBetweenPorts(Port* p_port1, Port* p_port2);

    // Get the port registered for the given lid, or nullptr when the lid is
    // outside the lid table. The bound check is done in size_t so that
    // "lid + 1" can never wrap around and bypass it.
    Port* GetPortByLid(lid_t lid) const
    {
        if (static_cast<std::size_t>(lid) >= m_port_by_lid_.size())
            return nullptr;
        return m_port_by_lid_[lid];
    }

    // dump out the contents of the entire fabric
    int Dump(ostream& sout) const;
    int Dump() const;

    int AddRootNode(const uint64_t root_guid);

    void RemoveLidPort(Port* p_port);
    void SetLidPort(lid_t lid, Port* p_port);

    void CleanAggNode(Port* p_port);   // Send a clean-all MAD to AN

    void SetAMKeys();
    void SetAMKey(Port* p_port);

    // Configure all required AggNodes.
    void ConfigureAggNode(Port* p_port);
    void ConfigureAggNodePorts(Port* p_port, AggregatedLogMessage& configured_ports_msg);

    void UpdateTopologyData();

    void CalculateAggNodeGraph();
    void CalculateMinhopAggNodeGraph();
    void CalculateTreeAggNodeGraph();
    void CalculateAnMinHopsTables();
    void CalculateDfpAnMinHopsTables();

    void BuildCaUpdateList();
    void BuildVportsUpdateList();
    void BuildPathsUpdateList();

    void CreateCaPortData(Port* p_port, const PortInfo& port_info);
    void CreateVPortData(Vport* p_vport, const PortInfo& vport_info);
    int DiscoverAggNode(Port* p_port);
    int RecoverAMKey(Port* p_port);
    void AddUpdateAn(PortData* port_data, FabricUpdateType update_type);
    int AddComputePort(PortData* p_port_data, const string& host_name);

    void CheckIsCaPortUpdate(Port* p_port, const PortInfo& port_info, AggregatedLogMessage& unsupported_change_msgs);
    int RediscoverAggNode(Port* p_port);
    void CreateVportUpdateIfNeeded(Vport* p_vport,
                                   const PortInfo& new_vport_info,
                                   std::vector<uint64_t>& old_vports_guids,
                                   AggregatedLogMessage& disabled_vports_guids,
                                   AggregatedLogMessage& enabled_vports_guids);

    static void GetHostName(const PortInfo& port_info, string& hca_host_name);

    // A lid is valid when non-zero and not above FABRIC_MAX_VALID_LID.
    bool IsValidLid(lid_t lid) const { return (lid && lid <= FABRIC_MAX_VALID_LID); }

    void SetDataPathVersion();

    static node_min_hop_key_t GetNodeMinHopKey(Port* port);
};

#endif   // AGG_FABRIC_GRAPH_H_
