/**************************************************************************** * Copyright (c) 2024 PX4 Development Team. * SPDX-License-Identifier: BSD-3-Clause ****************************************************************************/ #pragma once #include "util.h" #include #include #include #include #include #include #include #include #include // This implements a directed graph with potential cycles used for translation. // There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and // translations. Translation nodes are always in between message nodes, and can have N input messages // and M output messages. struct MessageIdentifier { std::string topic_name; MessageVersionType version; bool operator==(const MessageIdentifier& other) const { return topic_name == other.topic_name && version == other.version; } bool operator!=(const MessageIdentifier& other) const { return !(*this == other); } }; template<> struct std::hash { std::size_t operator()(const MessageIdentifier& s) const noexcept { std::size_t h1 = std::hash{}(s.topic_name); std::size_t h2 = std::hash{}(s.version); return h1 ^ (h2 << 1); } }; using MessageBuffer = std::shared_ptr; template class MessageNode; template class Graph; template using MessageNodePtrT = std::shared_ptr>; template class TranslationNode { public: using TranslationCB = std::function&, std::vector&)>; TranslationNode(std::vector> inputs, std::vector> outputs, TranslationCB translation_db) : _inputs(std::move(inputs)), _outputs(std::move(outputs)), _translation_cb(std::move(translation_db)) { assert(_inputs.size() <= kMaxNumInputs); _input_buffers.resize(_inputs.size()); for (unsigned i = 0; i < _inputs.size(); ++i) { _input_buffers[i] = _inputs[i]->buffer(); } _output_buffers.resize(_outputs.size()); for (unsigned i = 0; i < _outputs.size(); ++i) { _output_buffers[i] = _outputs[i]->buffer(); } } void setInputReady(unsigned index) { _inputs_ready.set(index); } bool translate() { if (_inputs_ready.count() == _input_buffers.size()) { _translation_cb(_input_buffers, _output_buffers); _inputs_ready.reset(); return true; } return false; } const std::vector>& inputs() const { return _inputs; } const std::vector>& outputs() const { return _outputs; } private: static constexpr int kMaxNumInputs = 32; const std::vector> _inputs; std::vector _input_buffers; ///< Cached buffers from _inputs.buffer() const std::vector> _outputs; std::vector _output_buffers; const TranslationCB _translation_cb; std::bitset _inputs_ready; }; template using TranslationNodePtrT = std::shared_ptr>; template class MessageNode { public: explicit MessageNode(NodeData node_data, size_t index, MessageBuffer message_buffer) : _buffer(std::move(message_buffer)), _data(std::move(node_data)), _index(index) {} MessageBuffer& buffer() { return _buffer; } void addTranslationInput(TranslationNodePtrT node, unsigned input_index) { _translations.push_back(Translation{std::move(node), input_index}); } NodeData& data() { return _data; } void resetNodes() { _translations.clear(); } private: struct Translation { TranslationNodePtrT node; ///< Counterpart to the TranslationNode::_inputs unsigned input_index; ///< Index into the TranslationNode::_inputs }; MessageBuffer _buffer; std::vector _translations; NodeData _data; const size_t _index; friend class Graph; }; template class Graph { public: using MessageNodePtr = MessageNodePtrT; ~Graph() { // Explicitly reset the nodes array to break up potential cycles and prevent memory leaks for (auto& [id, node] : _nodes) { node->resetNodes(); } } /** * @brief Add a message node if it does not exist already */ bool addNodeIfNotExists(const IdType& id, NodeData node_data, const MessageBuffer& message_buffer) { if (_nodes.find(id) != _nodes.end()) { return false; } // Node that we cannot remove nodes due to using the index as an array index const size_t index = _nodes.size(); _nodes.insert({id, std::make_shared>(std::move(node_data), index, message_buffer)}); return true; } /** * @brief Add a translation edge with N inputs and M output nodes. All nodes must already exist. */ void addTranslation(const typename TranslationNode::TranslationCB& translation_cb, const std::vector& inputs, const std::vector& outputs) { auto init = [this](const std::vector& from, std::vector>& to) { for (unsigned i=0; i < from.size(); ++i) { auto node_iter = _nodes.find(from[i]); assert(node_iter != _nodes.end()); to[i] = node_iter->second; } }; std::vector> input_nodes(inputs.size()); init(inputs, input_nodes); std::vector> output_nodes(outputs.size()); init(outputs, output_nodes); auto translation_node = std::make_shared>(std::move(input_nodes), std::move(output_nodes), translation_cb); for (unsigned i=0; i < translation_node->inputs().size(); ++i) { translation_node->inputs()[i]->addTranslationInput(translation_node, i); } } /** * @brief Translate a message node in the graph. * * @param node The message node to translate. * @param on_translated A callback function that is called for translated nodes (with an updated message buffer). * This will not be called for the provided node. */ void translate(const MessageNodePtr& node, const std::function& on_translated) { resetNodesVisited(); // Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm, // while using translation nodes as barriers (only continue when all inputs are ready) std::queue queue; _node_visited[node->_index] = true; queue.push(node); while (!queue.empty()) { MessageNodePtr current = queue.front(); queue.pop(); for (auto& translation : current->_translations) { const bool any_output_visited = std::any_of(translation.node->outputs().begin(), translation.node->outputs().end(), [&](const MessageNodePtr& next_node) { return _node_visited[next_node->_index]; }); // If any output node has already been visited, skip this translation node (prevents translating // backwards, from where we came from already) if (any_output_visited) { continue; } translation.node->setInputReady(translation.input_index); // Iterate the output nodes only if the translation node is ready if (translation.node->translate()) { for (auto &next_node : translation.node->outputs()) { if (_node_visited[next_node->_index]) { continue; } _node_visited[next_node->_index] = true; on_translated(next_node); queue.push(next_node); } } } } } std::optional findNode(const IdType& id) const { auto iter = _nodes.find(id); if (iter == _nodes.end()) { return std::nullopt; } return iter->second; } void iterateNodes(const std::function& cb) const { for (const auto& [id, node] : _nodes) { cb(id, node); } } /** * Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm */ void iterateBFS(const MessageNodePtr& node, const std::function& cb) { resetNodesVisited(); std::queue queue; _node_visited[node->_index] = true; queue.push(node); cb(node); while (!queue.empty()) { MessageNodePtr current = queue.front(); queue.pop(); for (auto& translation : current->_translations) { for (auto& next_node : translation.node->outputs()) { if (_node_visited[next_node->_index]) { continue; } _node_visited[next_node->_index] = true; queue.push(next_node); cb(next_node); } } } } private: void resetNodesVisited() { _node_visited.resize(_nodes.size()); std::fill(_node_visited.begin(), _node_visited.end(), false); } std::unordered_map _nodes; std::vector _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration };