msg: add message translation node for ROS

This commit is contained in:
Beat Küng
2024-10-29 14:50:08 +01:00
committed by Silvan Fuhrer
parent 975ec30c9c
commit f6bfa9812e
32 changed files with 3200 additions and 0 deletions
+293
View File
@@ -0,0 +1,293 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include "util.h"
#include <algorithm>
#include <bitset>
#include <functional>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>
// This implements a directed graph with potential cycles used for translation.
// There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and
// translations. Translation nodes are always in between message nodes, and can have N input messages
// and M output messages.
struct MessageIdentifier {
std::string topic_name;
MessageVersionType version;
bool operator==(const MessageIdentifier& other) const {
return topic_name == other.topic_name && version == other.version;
}
bool operator!=(const MessageIdentifier& other) const {
return !(*this == other);
}
};
template<>
struct std::hash<MessageIdentifier>
{
std::size_t operator()(const MessageIdentifier& s) const noexcept
{
std::size_t h1 = std::hash<std::string>{}(s.topic_name);
std::size_t h2 = std::hash<std::uint32_t>{}(s.version);
return h1 ^ (h2 << 1);
}
};
using MessageBuffer = std::shared_ptr<void>;
template <typename NodeData, typename IdType>
class MessageNode;
template <typename NodeData, typename IdType>
class Graph;
template <typename NodeData, typename IdType>
using MessageNodePtrT = std::shared_ptr<MessageNode<NodeData, MessageIdentifier>>;
template <typename NodeData, typename IdType=MessageIdentifier>
class TranslationNode {
public:
using TranslationCB = std::function<void(const std::vector<MessageBuffer>&, std::vector<MessageBuffer>&)>;
TranslationNode(std::vector<MessageNodePtrT<NodeData, IdType>> inputs,
std::vector<MessageNodePtrT<NodeData, IdType>> outputs,
TranslationCB translation_db)
: _inputs(std::move(inputs)), _outputs(std::move(outputs)), _translation_cb(std::move(translation_db)) {
assert(_inputs.size() <= kMaxNumInputs);
_input_buffers.resize(_inputs.size());
for (unsigned i = 0; i < _inputs.size(); ++i) {
_input_buffers[i] = _inputs[i]->buffer();
}
_output_buffers.resize(_outputs.size());
for (unsigned i = 0; i < _outputs.size(); ++i) {
_output_buffers[i] = _outputs[i]->buffer();
}
}
void setInputReady(unsigned index) {
_inputs_ready.set(index);
}
bool translate() {
if (_inputs_ready.count() == _input_buffers.size()) {
_translation_cb(_input_buffers, _output_buffers);
_inputs_ready.reset();
return true;
}
return false;
}
const std::vector<MessageNodePtrT<NodeData, IdType>>& inputs() const { return _inputs; }
const std::vector<MessageNodePtrT<NodeData, IdType>>& outputs() const { return _outputs; }
private:
static constexpr int kMaxNumInputs = 32;
const std::vector<MessageNodePtrT<NodeData, IdType>> _inputs;
std::vector<MessageBuffer> _input_buffers; ///< Cached buffers from _inputs.buffer()
const std::vector<MessageNodePtrT<NodeData, IdType>> _outputs;
std::vector<MessageBuffer> _output_buffers;
const TranslationCB _translation_cb;
std::bitset<kMaxNumInputs> _inputs_ready;
};
template <typename NodeData, typename IdType>
using TranslationNodePtrT = std::shared_ptr<TranslationNode<NodeData, MessageIdentifier>>;
template <typename NodeData, typename IdType=MessageIdentifier>
class MessageNode {
public:
explicit MessageNode(NodeData node_data, size_t index, MessageBuffer message_buffer)
: _buffer(std::move(message_buffer)), _data(std::move(node_data)), _index(index) {}
MessageBuffer& buffer() { return _buffer; }
void addTranslationInput(TranslationNodePtrT<NodeData, IdType> node, unsigned input_index) {
_translations.push_back(Translation{std::move(node), input_index});
}
NodeData& data() { return _data; }
void resetNodes() {
_translations.clear();
}
private:
struct Translation {
TranslationNodePtrT<NodeData, IdType> node; ///< Counterpart to the TranslationNode::_inputs
unsigned input_index; ///< Index into the TranslationNode::_inputs
};
MessageBuffer _buffer;
std::vector<Translation> _translations;
NodeData _data;
const size_t _index;
friend class Graph<NodeData, IdType>;
};
template <typename NodeData, typename IdType=MessageIdentifier>
class Graph {
public:
using MessageNodePtr = MessageNodePtrT<NodeData, IdType>;
~Graph() {
// Explicitly reset the nodes array to break up potential cycles and prevent memory leaks
for (auto& [id, node] : _nodes) {
node->resetNodes();
}
}
/**
* @brief Add a message node if it does not exist already
*/
bool addNodeIfNotExists(const IdType& id, NodeData node_data, const MessageBuffer& message_buffer) {
if (_nodes.find(id) != _nodes.end()) {
return false;
}
// Node that we cannot remove nodes due to using the index as an array index
const size_t index = _nodes.size();
_nodes.insert({id, std::make_shared<MessageNode<NodeData, IdType>>(std::move(node_data), index, message_buffer)});
return true;
}
/**
* @brief Add a translation edge with N inputs and M output nodes. All nodes must already exist.
*/
void addTranslation(const typename TranslationNode<NodeData, IdType>::TranslationCB& translation_cb,
const std::vector<IdType>& inputs, const std::vector<IdType>& outputs) {
auto init = [this](const std::vector<IdType>& from, std::vector<MessageNodePtrT<NodeData, IdType>>& to) {
for (unsigned i=0; i < from.size(); ++i) {
auto node_iter = _nodes.find(from[i]);
assert(node_iter != _nodes.end());
to[i] = node_iter->second;
}
};
std::vector<MessageNodePtrT<NodeData, IdType>> input_nodes(inputs.size());
init(inputs, input_nodes);
std::vector<MessageNodePtrT<NodeData, IdType>> output_nodes(outputs.size());
init(outputs, output_nodes);
auto translation_node = std::make_shared<TranslationNode<NodeData, IdType>>(std::move(input_nodes), std::move(output_nodes), translation_cb);
for (unsigned i=0; i < translation_node->inputs().size(); ++i) {
translation_node->inputs()[i]->addTranslationInput(translation_node, i);
}
}
/**
* @brief Translate a message node in the graph.
*
* @param node The message node to translate.
* @param on_translated A callback function that is called for translated nodes (with an updated message buffer).
* This will not be called for the provided node.
*/
void translate(const MessageNodePtr& node,
const std::function<void(const MessageNodePtr&)>& on_translated) {
resetNodesVisited();
// Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm,
// while using translation nodes as barriers (only continue when all inputs are ready)
std::queue<MessageNodePtr> queue;
_node_visited[node->_index] = true;
queue.push(node);
while (!queue.empty()) {
MessageNodePtr current = queue.front();
queue.pop();
for (auto& translation : current->_translations) {
const bool any_output_visited =
std::any_of(translation.node->outputs().begin(), translation.node->outputs().end(), [&](const MessageNodePtr& next_node) {
return _node_visited[next_node->_index];
});
// If any output node has already been visited, skip this translation node (prevents translating
// backwards, from where we came from already)
if (any_output_visited) {
continue;
}
translation.node->setInputReady(translation.input_index);
// Iterate the output nodes only if the translation node is ready
if (translation.node->translate()) {
for (auto &next_node : translation.node->outputs()) {
if (_node_visited[next_node->_index]) {
continue;
}
_node_visited[next_node->_index] = true;
on_translated(next_node);
queue.push(next_node);
}
}
}
}
}
std::optional<MessageNodePtr> findNode(const IdType& id) const {
auto iter = _nodes.find(id);
if (iter == _nodes.end()) {
return std::nullopt;
}
return iter->second;
}
void iterateNodes(const std::function<void(const IdType& type, const MessageNodePtr& node)>& cb) const {
for (const auto& [id, node] : _nodes) {
cb(id, node);
}
}
/**
* Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm
*/
void iterateBFS(const MessageNodePtr& node, const std::function<void(const MessageNodePtr&)>& cb) {
resetNodesVisited();
std::queue<MessageNodePtr> queue;
_node_visited[node->_index] = true;
queue.push(node);
cb(node);
while (!queue.empty()) {
MessageNodePtr current = queue.front();
queue.pop();
for (auto& translation : current->_translations) {
for (auto& next_node : translation.node->outputs()) {
if (_node_visited[next_node->_index]) {
continue;
}
_node_visited[next_node->_index] = true;
queue.push(next_node);
cb(next_node);
}
}
}
}
private:
void resetNodesVisited() {
_node_visited.resize(_nodes.size());
std::fill(_node_visited.begin(), _node_visited.end(), false);
}
std::unordered_map<IdType, MessageNodePtr> _nodes;
std::vector<bool> _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration
};
+39
View File
@@ -0,0 +1,39 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include <memory>
#include <rclcpp/rclcpp.hpp>
#include "../translations/all_translations.h"
#include "pub_sub_graph.h"
#include "service_graph.h"
#include "monitor.h"
using namespace std::chrono_literals;
class RosTranslationNode : public rclcpp::Node
{
public:
RosTranslationNode() : Node("translation_node")
{
_pub_sub_graph = std::make_unique<PubSubGraph>(*this, RegisteredTranslations::instance().topicTranslations());
_service_graph = std::make_unique<ServiceGraph>(*this, RegisteredTranslations::instance().serviceTranslations());
_monitor = std::make_unique<Monitor>(*this, _pub_sub_graph.get(), _service_graph.get());
}
private:
std::unique_ptr<PubSubGraph> _pub_sub_graph;
std::unique_ptr<ServiceGraph> _service_graph;
rclcpp::TimerBase::SharedPtr _node_update_timer;
std::unique_ptr<Monitor> _monitor;
};
int main(int argc, char * argv[])
{
rclcpp::init(argc, argv);
rclcpp::spin(std::make_shared<RosTranslationNode>());
rclcpp::shutdown();
return 0;
}
+60
View File
@@ -0,0 +1,60 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include "monitor.h"
using namespace std::chrono_literals;
Monitor::Monitor(rclcpp::Node &node, PubSubGraph* pub_sub_graph, ServiceGraph* service_graph)
: _node(node), _pub_sub_graph(pub_sub_graph), _service_graph(service_graph) {
// Monitor subscriptions & publishers
// TODO: event-based
_node_update_timer = _node.create_wall_timer(1s, [this]() {
updateNow();
});
}
void Monitor::updateNow() {
// Topics
if (_pub_sub_graph != nullptr) {
std::vector<PubSubGraph::TopicInfo> topic_info;
const auto topics = _node.get_topic_names_and_types();
for (const auto &[topic_name, topic_types]: topics) {
auto publishers = _node.get_publishers_info_by_topic(topic_name);
auto subscribers = _node.get_subscriptions_info_by_topic(topic_name);
// Filter out self
int num_publishers = 0;
for (const auto &publisher: publishers) {
num_publishers += publisher.node_name() != _node.get_name();
}
int num_subscribers = 0;
for (const auto &subscriber: subscribers) {
num_subscribers += subscriber.node_name() != _node.get_name();
}
if (num_subscribers > 0 || num_publishers > 0) {
topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers});
}
}
_pub_sub_graph->updateCurrentTopics(topic_info);
}
// Services
#ifndef DISABLE_SERVICES // ROS Humble does not support the count_services() call
if (_service_graph != nullptr) {
std::vector<ServiceGraph::ServiceInfo> service_info;
const auto services = _node.get_service_names_and_types();
for (const auto& [service_name, service_types] : services) {
const int num_services = _node.get_node_graph_interface()->count_services(service_name);
const int num_clients = _node.get_node_graph_interface()->count_clients(service_name);
// We cannot filter out our own node, as we don't have that info.
// We could use `get_service_names_and_types_by_node`, but then we would not get
// services by non-ros nodes (e.g. microxrce dds bridge)
service_info.emplace_back(ServiceGraph::ServiceInfo{service_name, num_services, num_clients});
}
_service_graph->updateCurrentServices(service_info);
}
#endif
}
+23
View File
@@ -0,0 +1,23 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <rclcpp/rclcpp.hpp>
#include "pub_sub_graph.h"
#include "service_graph.h"
#include <functional>
class Monitor {
public:
explicit Monitor(rclcpp::Node &node, PubSubGraph* pub_sub_graph, ServiceGraph* service_graph);
void updateNow();
private:
rclcpp::Node &_node;
PubSubGraph* _pub_sub_graph{nullptr};
ServiceGraph* _service_graph{nullptr};
rclcpp::TimerBase::SharedPtr _node_update_timer;
};
+195
View File
@@ -0,0 +1,195 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include "pub_sub_graph.h"
#include "util.h"
PubSubGraph::PubSubGraph(rclcpp::Node &node, const TopicTranslations &translations) : _node(node) {
std::unordered_map<std::string, std::set<MessageVersionType>> known_versions;
for (const auto& topic : translations.topics()) {
const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), topic.id.topic_name);
_known_topics_warned.insert({full_topic_name, false});
const MessageIdentifier id{full_topic_name, topic.id.version};
NodeDataPubSub node_data{topic.subscription_factory, topic.publication_factory, id, topic.max_serialized_message_size};
_pub_sub_graph.addNodeIfNotExists(id, std::move(node_data), topic.message_buffer);
known_versions[full_topic_name].insert(id.version);
}
auto get_full_topic_names = [this](std::vector<MessageIdentifier> ids) {
for (auto& id : ids) {
id.topic_name = getFullTopicName(_node.get_effective_namespace(), id.topic_name);
}
return ids;
};
for (const auto& translation : translations.translations()) {
const std::vector<MessageIdentifier> inputs = get_full_topic_names(translation.inputs);
const std::vector<MessageIdentifier> outputs = get_full_topic_names(translation.outputs);
_pub_sub_graph.addTranslation(translation.cb, inputs, outputs);
}
printTopicInfo(known_versions);
handleLargestTopic(known_versions);
}
void PubSubGraph::updateCurrentTopics(const std::vector<TopicInfo> &topics) {
_pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph<NodeDataPubSub>::MessageNodePtr& node) {
node->data().has_external_publisher = false;
node->data().has_external_subscriber = false;
node->data().visited = false;
});
for (const auto& info : topics) {
const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.topic_name);
auto maybe_node = _pub_sub_graph.findNode({non_versioned_topic_name, version});
if (!maybe_node) {
auto known_topic_iter = _known_topics_warned.find(non_versioned_topic_name);
if (known_topic_iter != _known_topics_warned.end() && !known_topic_iter->second) {
RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of topic %s", version, non_versioned_topic_name.c_str());
known_topic_iter->second = true;
}
continue;
}
const auto& node = maybe_node.value();
if (info.num_publishers > 0) {
node->data().has_external_publisher = true;
}
if (info.num_subscribers > 0) {
node->data().has_external_subscriber = true;
}
}
// Iterate connected graph segments
_pub_sub_graph.iterateNodes([this](const MessageIdentifier& type, const Graph<NodeDataPubSub>::MessageNodePtr& node) {
if (node->data().visited) {
return;
}
node->data().visited = true;
// Count the number of external subscribers and publishers for each connected graph
int num_publishers = 0;
int num_subscribers = 0;
int num_subscribers_without_publisher = 0;
_pub_sub_graph.iterateBFS(node, [&](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
if (node->data().has_external_publisher) {
++num_publishers;
}
if (node->data().has_external_subscriber) {
++num_subscribers;
if (!node->data().has_external_publisher) {
++num_subscribers_without_publisher;
}
}
});
// We need to instantiate publishers and subscribers if:
// - there are multiple publishers and at least 1 subscriber
// - there is 1 publisher and at least 1 subscriber on another node
// Note that in case of splitting or merging topics, this might create more entities than actually needed
const bool require_translation = (num_publishers >= 2 && num_subscribers >= 1)
|| (num_publishers == 1 && num_subscribers_without_publisher >= 1);
if (require_translation) {
_pub_sub_graph.iterateBFS(node, [&](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
node->data().visited = true;
// Has subscriber(s)?
if (node->data().has_external_subscriber && !node->data().publication) {
RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", node->data().topic_name.c_str(), node->data().version);
node->data().publication = node->data().publication_factory(_node);
} else if (!node->data().has_external_subscriber && node->data().publication) {
RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", node->data().topic_name.c_str(), node->data().version);
node->data().publication.reset();
}
// Has publisher(s)?
if (node->data().has_external_publisher && !node->data().subscription) {
RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", node->data().topic_name.c_str(), node->data().version);
node->data().subscription = node->data().subscription_factory(_node, [this, node_cpy=node]() {
onSubscriptionUpdate(node_cpy);
});
} else if (!node->data().has_external_publisher && node->data().subscription) {
RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", node->data().topic_name.c_str(), node->data().version);
node->data().subscription.reset();
}
});
} else {
// Reset any publishers or subscribers
_pub_sub_graph.iterateBFS(node, [&](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
node->data().visited = true;
if (node->data().publication) {
RCLCPP_INFO(_node.get_logger(), "Removing publisher for topic '%s', version: %i",
node->data().topic_name.c_str(), node->data().version);
node->data().publication.reset();
}
if (node->data().subscription) {
RCLCPP_INFO(_node.get_logger(), "Removing subscriber for topic '%s', version: %i",
node->data().topic_name.c_str(), node->data().version);
node->data().subscription.reset();
}
});
}
});
}
void PubSubGraph::onSubscriptionUpdate(const Graph<NodeDataPubSub>::MessageNodePtr& node) {
_pub_sub_graph.translate(
node,
[this](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
if (node->data().publication != nullptr) {
const auto ret = rcl_publish(node->data().publication->get_publisher_handle().get(),
node->buffer().get(), nullptr);
if (ret != RCL_RET_OK) {
RCLCPP_WARN_ONCE(_node.get_logger(), "Failed to publish on topic '%s', version: %i",
node->data().topic_name.c_str(), node->data().version);
}
}
});
}
void PubSubGraph::printTopicInfo(const std::unordered_map<std::string, std::set<MessageVersionType>>& known_versions) const {
// Print info about known versions
RCLCPP_INFO(_node.get_logger(), "Registered pub/sub topics and versions:");
for (const auto& [topic_name, version_set] : known_versions) {
if (version_set.empty()) {
continue;
}
const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(),
std::to_string(*version_set.begin()), // start with first element
[](std::string a, auto&& b) {
return std::move(a) + ", " + std::to_string(b);
});
RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str());
}
}
void PubSubGraph::handleLargestTopic(const std::unordered_map<std::string, std::set<MessageVersionType>> &known_versions) {
// FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given
// type. The information that is relevant for us is the maximum serialized message size.
// Since different versions can have different sizes, we need to ensure the first publication or subscription
// happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered.
// And the type must continue to be in use (so we cannot delete it)
for (const auto& [topic_name, versions] : known_versions) {
size_t max_serialized_message_size = 0;
const PublicationFactoryCB* publication_factory_for_max = nullptr;
for (auto version : versions) {
const auto& node = _pub_sub_graph.findNode(MessageIdentifier{topic_name, version});
assert(node);
const auto& node_data = node.value()->data();
if (node_data.max_serialized_message_size > max_serialized_message_size) {
max_serialized_message_size = node_data.max_serialized_message_size;
publication_factory_for_max = &node_data.publication_factory;
}
}
if (publication_factory_for_max) {
_largest_topic_publications.emplace_back((*publication_factory_for_max)(_node));
}
}
}
+58
View File
@@ -0,0 +1,58 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <rclcpp/rclcpp.hpp>
#include <utility>
#include "translations.h"
#include "translation_util.h"
#include "graph.h"
class PubSubGraph {
public:
struct TopicInfo {
std::string topic_name; ///< fully qualified topic name (with namespace)
int num_subscribers; ///< does not include this node's subscribers
int num_publishers; ///< does not include this node's publishers
};
PubSubGraph(rclcpp::Node& node, const TopicTranslations& translations);
void updateCurrentTopics(const std::vector<TopicInfo>& topics);
private:
struct NodeDataPubSub {
explicit NodeDataPubSub(SubscriptionFactoryCB subscription_factory, PublicationFactoryCB publication_factory,
const MessageIdentifier& id, size_t max_serialized_message_size)
: subscription_factory(std::move(subscription_factory)), publication_factory(std::move(publication_factory)),
topic_name(id.topic_name), version(id.version), max_serialized_message_size(max_serialized_message_size)
{ }
const SubscriptionFactoryCB subscription_factory;
const PublicationFactoryCB publication_factory;
const std::string topic_name;
const MessageVersionType version;
const size_t max_serialized_message_size;
// Keep track if there's currently a publisher/subscriber
bool has_external_publisher{false};
bool has_external_subscriber{false};
rclcpp::SubscriptionBase::SharedPtr subscription;
rclcpp::PublisherBase::SharedPtr publication;
bool visited{false};
};
void onSubscriptionUpdate(const Graph<NodeDataPubSub>::MessageNodePtr& node);
void printTopicInfo(const std::unordered_map<std::string, std::set<MessageVersionType>>& known_versions) const;
void handleLargestTopic(const std::unordered_map<std::string, std::set<MessageVersionType>>& known_versions);
rclcpp::Node& _node;
Graph<NodeDataPubSub> _pub_sub_graph;
std::unordered_map<std::string, bool> _known_topics_warned;
std::vector<rclcpp::PublisherBase::SharedPtr> _largest_topic_publications;
};
+230
View File
@@ -0,0 +1,230 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include "service_graph.h"
#include <utility>
using namespace std::chrono_literals;
ServiceGraph::ServiceGraph(rclcpp::Node &node, const ServiceTranslations& translations)
: _node(node) {
std::unordered_map<std::string, std::set<MessageVersionType>> known_versions;
for (const auto& service : translations.nodes()) {
const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), service.id.topic_name);
_known_services_warned.insert({full_topic_name, false});
const MessageIdentifier id{full_topic_name, service.id.version};
auto node_data = std::make_shared<NodeDataService>(service, id);
_request_graph.addNodeIfNotExists(id, node_data, service.message_buffer_request);
_response_graph.addNodeIfNotExists(id, node_data, service.message_buffer_response);
known_versions[full_topic_name].insert(id.version);
}
auto get_full_topic_names = [this](std::vector<MessageIdentifier> ids) {
for (auto& id : ids) {
id.topic_name = getFullTopicName(_node.get_effective_namespace(), id.topic_name);
}
return ids;
};
for (const auto& translation : translations.requestTranslations()) {
const std::vector<MessageIdentifier> inputs = get_full_topic_names(translation.inputs);
const std::vector<MessageIdentifier> outputs = get_full_topic_names(translation.outputs);
_request_graph.addTranslation(translation.cb, inputs, outputs);
}
for (const auto& translation : translations.responseTranslations()) {
const std::vector<MessageIdentifier> inputs = get_full_topic_names(translation.inputs);
const std::vector<MessageIdentifier> outputs = get_full_topic_names(translation.outputs);
_response_graph.addTranslation(translation.cb, inputs, outputs);
}
printServiceInfo(known_versions);
handleLargestTopic(known_versions);
_cleanup_timer = _node.create_wall_timer(10s, [this]() {
cleanupStaleRequests();
});
}
void ServiceGraph::updateCurrentServices(const std::vector<ServiceInfo> &services) {
_request_graph.iterateNodes([](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) {
node->data()->has_service = false;
node->data()->has_client = false;
node->data()->visited = false;
});
for (const auto& info : services) {
const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.service_name);
auto maybe_node = _request_graph.findNode({non_versioned_topic_name, version});
if (!maybe_node) {
auto known_topic_iter = _known_services_warned.find(non_versioned_topic_name);
if (known_topic_iter != _known_services_warned.end() && !known_topic_iter->second) {
RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of service %s", version, non_versioned_topic_name.c_str());
known_topic_iter->second = true;
}
continue;
}
const auto& node = maybe_node.value();
if (info.num_services > 0) {
node->data()->has_service = true;
}
if (info.num_clients > 0) {
node->data()->has_client = true;
}
}
// Iterate connected graph segments
_request_graph.iterateNodes([this](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) {
if (node->data()->visited) {
return;
}
node->data()->visited = true;
// Check if there's a reachable node with a service
int num_services = 0;
_request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) {
if (node->data()->has_service && !node->data()->service) {
++num_services;
}
});
// We need to instantiate a service and clients if there's exactly one external service.
if (num_services > 1 ) {
RCLCPP_ERROR_ONCE(_node.get_logger(), "Found %i services for service '%s', skipping this service",
num_services, node->data()->service_name.c_str());
} else if (num_services == 1) {
_request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) {
node->data()->visited = true;
if (node->data()->has_service && !node->data()->client && !node->data()->service) {
RCLCPP_INFO(_node.get_logger(), "Found service for '%s', version: %i, adding client", node->data()->service_name.c_str(), node->data()->version);
auto tuple = node->data()->client_factory(_node, [this, tmp_node=node](rmw_request_id_t& request) {
onResponse(request, tmp_node);
});
node->data()->client = std::get<0>(tuple);
node->data()->client_send_cb = std::get<1>(tuple);
} else if (!node->data()->has_service && !node->data()->service && node->data()->has_client) {
RCLCPP_INFO(_node.get_logger(), "Found client for '%s', version: %i, adding service", node->data()->service_name.c_str(), node->data()->version);
node->data()->service = node->data()->service_factory(_node, [this, tmp_node=node](std::shared_ptr<rmw_request_id_t> req_id) {
onNewRequest(std::move(req_id), tmp_node);
});
}
});
} else {
// Reset any service or client
_request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) {
node->data()->visited = true;
if (node->data()->service) {
RCLCPP_INFO(_node.get_logger(), "Removing service for '%s', version: %i",
node->data()->service_name.c_str(), node->data()->version);
node->data()->service.reset();
}
if (node->data()->client) {
RCLCPP_INFO(_node.get_logger(), "Removing client for '%s', version: %i",
node->data()->service_name.c_str(), node->data()->version);
node->data()->client.reset();
}
});
}
});
}
void ServiceGraph::printServiceInfo(const std::unordered_map<std::string, std::set<MessageVersionType>>& known_versions) const {
// Print info about known versions
RCLCPP_INFO(_node.get_logger(), "Registered services and versions:");
for (const auto& [topic_name, version_set] : known_versions) {
if (version_set.empty()) {
continue;
}
const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(),
std::to_string(*version_set.begin()), // start with first element
[](std::string a, auto&& b) {
return std::move(a) + ", " + std::to_string(b);
});
RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str());
}
}
void ServiceGraph::handleLargestTopic(const std::unordered_map<std::string, std::set<MessageVersionType>> &known_versions) {
// See PubSubGraph::handleLargestTopic for an explanation why this is needed
unsigned index = 0;
for (const auto& [topic_name, versions] : known_versions) {
std::array<size_t, 2> max_serialized_message_size{0, 0};
std::array<const NamedPublicationFactoryCB*, 2> publication_factory_for_max{nullptr, nullptr};
for (auto version : versions) {
const auto& node = _request_graph.findNode(MessageIdentifier{topic_name, version});
assert(node);
const auto& node_data = node.value()->data();
for (unsigned i = 0; i < max_serialized_message_size.size(); ++i) {
if (node_data->max_serialized_message_size[i] > max_serialized_message_size[i]) {
max_serialized_message_size[i] = node_data->max_serialized_message_size[i];
publication_factory_for_max[i] = &node_data->publication_factory[i];
}
}
}
for (unsigned i = 0; i < max_serialized_message_size.size(); ++i) {
if (publication_factory_for_max[i]) {
const std::string tmp_topic_name = "dummy_topic" + std::to_string(index++);
_largest_topic_publications.emplace_back((*publication_factory_for_max[i])(_node, tmp_topic_name));
}
}
}
}
void ServiceGraph::onNewRequest(std::shared_ptr<rmw_request_id_t> req_id, GraphForService::MessageNodePtr node) {
bool service_called = false;
_request_graph.translate(node, [this, &service_called, &req_id, original_node=node](const GraphForService::MessageNodePtr& node) {
if (node->data()->client && node->data()->client_send_cb && !service_called) {
service_called = true;
const int64_t client_request_id = node->data()->client_send_cb(node->buffer());
node->data()->ongoing_requests[client_request_id] = Request{req_id, original_node->data(), _node.now()};
}
});
}
void ServiceGraph::onResponse(rmw_request_id_t &req_id, GraphForService::MessageNodePtr node) {
auto iter = node->data()->ongoing_requests.find(req_id.sequence_number);
if (iter == node->data()->ongoing_requests.end()) {
RCLCPP_ERROR(_node.get_logger(), "Got response with unknown request %li", req_id.sequence_number);
return;
}
bool service_called = false;
auto response_node = _response_graph.findNode({node->data()->service_name, node->data()->version});
assert(response_node);
_response_graph.translate(response_node.value(), [this, &service_called, &iter](const GraphForService::MessageNodePtr &node) {
if (node->data()->service && !service_called && iter->second.original_node_data == node->data()) {
const rcl_ret_t ret = rcl_send_response(node->data()->service->get_service_handle().get(),
iter->second.original_request_id.get(), node->buffer().get());
if (ret != RCL_RET_OK) {
RCLCPP_ERROR(_node.get_logger(), "Failed to send response: %s", rcl_get_error_string().str);
}
service_called = true;
}
});
node->data()->ongoing_requests.erase(iter);
}
void ServiceGraph::cleanupStaleRequests() {
static const auto kRequestTimeout = 20s;
_request_graph.iterateNodes([this](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) {
for (auto it = node->data()->ongoing_requests.begin(); it != node->data()->ongoing_requests.end();) {
const auto& request = it->second;
if (_node.now() - request.timestamp_received > kRequestTimeout) {
RCLCPP_INFO(_node.get_logger(), "Request timed out, dropping ongoing request for '%s', version: %i, request id: %li",
node->data()->service_name.c_str(), node->data()->version, request.original_request_id->sequence_number);
it = node->data()->ongoing_requests.erase(it);
} else {
++it;
}
}
});
}
+76
View File
@@ -0,0 +1,76 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <rclcpp/rclcpp.hpp>
#include <utility>
#include "translations.h"
#include "translation_util.h"
#include "graph.h"
class ServiceGraph {
public:
struct ServiceInfo {
std::string service_name; ///< fully qualified service name (with namespace)
int num_services; ///< This can include a service created by the translation node
int num_clients; ///< This can include a client created by the translation node
};
ServiceGraph(rclcpp::Node &node, const ServiceTranslations& translations);
void updateCurrentServices(const std::vector<ServiceInfo>& services);
private:
struct NodeDataService;
using GraphForService = Graph<std::shared_ptr<NodeDataService>>;
void printServiceInfo(const std::unordered_map<std::string, std::set<MessageVersionType>> &known_versions) const;
void handleLargestTopic(const std::unordered_map<std::string, std::set<MessageVersionType>>& known_versions);
void onNewRequest(std::shared_ptr<rmw_request_id_t> req_id, GraphForService::MessageNodePtr node);
void onResponse(rmw_request_id_t& req_id, GraphForService::MessageNodePtr node);
void cleanupStaleRequests();
struct Request {
std::shared_ptr<rmw_request_id_t> original_request_id;
std::shared_ptr<NodeDataService> original_node_data{nullptr};
rclcpp::Time timestamp_received;
};
struct NodeDataService {
explicit NodeDataService(const Service& service, const MessageIdentifier& id)
: service_factory(service.service_factory), client_factory(service.client_factory),
service_name(id.topic_name), version(id.version),
publication_factory{service.publication_factory_request, service.publication_factory_response},
max_serialized_message_size{service.max_serialized_message_size_request, service.max_serialized_message_size_response}
{ }
const ServiceFactoryCB service_factory;
const ClientFactoryCB client_factory;
const std::string service_name;
const MessageVersionType version;
const std::array<NamedPublicationFactoryCB, 2> publication_factory; // Request/Response
const std::array<size_t, 2> max_serialized_message_size;
// Keep track if there's currently a client/service
bool has_service{false};
bool has_client{false};
rclcpp::ClientBase::SharedPtr client;
ClientSendCB client_send_cb;
rclcpp::ServiceBase::SharedPtr service;
std::unordered_map<int64_t, Request> ongoing_requests; ///< Ongoing service calls for this node
bool visited{false};
};
rclcpp::Node& _node;
GraphForService _request_graph;
GraphForService _response_graph;
std::unordered_map<std::string, bool> _known_services_warned;
rclcpp::TimerBase::SharedPtr _cleanup_timer;
std::vector<rclcpp::PublisherBase::SharedPtr> _largest_topic_publications;
};
+64
View File
@@ -0,0 +1,64 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <memory>
#include <utility>
#include <type_traits>
#include <vector>
/**
* Helper struct to store template parameter packs
*/
template <typename... Args>
struct Pack {
};
/**
* Struct for a template parameter pack with access to the individual types
*/
template<typename ...Types>
struct TypesArray {
template<typename T, typename...OtherTypes>
struct TypeHelper {
using Type = T;
using Next = TypeHelper<OtherTypes..., void>;
};
using Type1 = typename TypeHelper<Types...>::Type;
using Type2 = typename TypeHelper<Types...>::Next::Type;
using Type3 = typename TypeHelper<Types...>::Next::Next::Type;
using Type4 = typename TypeHelper<Types...>::Next::Next::Next::Type;
using Type5 = typename TypeHelper<Types...>::Next::Next::Next::Next::Type;
using Type6 = typename TypeHelper<Types...>::Next::Next::Next::Next::Next::Type;
using args = Pack<Types...>;
};
/**
* Helper for call_translation_function()
*/
template<typename F, typename MessageType, typename... ArgsIn, typename... ArgsOut, size_t... Is, size_t... Os>
inline void call_translation_function_impl(F f, Pack<ArgsIn...>, Pack<ArgsOut...>,
const std::vector<std::shared_ptr<MessageType>>& messages_in,
std::vector<std::shared_ptr<MessageType>>& messages_out,
std::integer_sequence<size_t, Is...>, std::integer_sequence<size_t, Os...>)
{
f(*static_cast<const ArgsIn*>(messages_in[Is].get())..., *static_cast<ArgsOut*>(messages_out[Os].get())...);
}
/**
* Call a translation function F which takes the arguments (const ArgsIn&..., ArgsOut&...),
* by passing messages_in and messages_out as arguments.
* Note that sizeof(ArgsIn) == messages_in.length() && sizeof(ArgsOut) == messages_out.length() must hold.
*/
template<typename F, typename MessageType, typename... ArgsIn, typename... ArgsOut>
inline void call_translation_function(F f, Pack<ArgsIn...> pack_in, Pack<ArgsOut...> pack_out,
const std::vector<std::shared_ptr<MessageType>>& messages_in,
std::vector<std::shared_ptr<MessageType>>& messages_out) {
call_translation_function_impl(f, pack_in, pack_out, messages_in, messages_out,
std::index_sequence_for<ArgsIn...>{}, std::index_sequence_for<ArgsOut...>{});
}
+386
View File
@@ -0,0 +1,386 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include "translations.h"
#include "util.h"
#include "template_util.h"
#include <rosidl_typesupport_cpp/message_type_support_dispatch.hpp>
#include <rosidl_typesupport_fastrtps_cpp/message_type_support.h>
class RegisteredTranslations {
public:
RegisteredTranslations(RegisteredTranslations const&) = delete;
void operator=(RegisteredTranslations const&) = delete;
static RegisteredTranslations& instance() {
static RegisteredTranslations instance;
return instance;
}
/**
* @brief Register a translation class with 1 input and 1 output message.
*
* The translation class has the form:
*
* ```
* class MyTranslation {
* public:
* using MessageOlder = px4_msgs_old::msg::VehicleAttitudeV2;
*
* using MessageNewer = px4_msgs::msg::VehicleAttitude;
*
* static constexpr const char* kTopic = "fmu/out/vehicle_attitude";
*
* static void fromOlder(const MessageOlder &msg_older, MessageNewer &msg_newer) {
* // set msg_newer from msg_older
* }
*
* static void toOlder(const MessageNewer &msg_newer, MessageOlder &msg_older) {
* // set msg_older from msg_newer
* }
* };
* ```
*/
template<class T>
void registerDirectTranslation() {
const std::string topic_name = T::kTopic;
_topic_translations.addTopic(getTopicForMessageType<typename T::MessageOlder>(topic_name));
_topic_translations.addTopic(getTopicForMessageType<typename T::MessageNewer>(topic_name));
// Translation callbacks
auto translation_cb_from_older = [](const std::vector<MessageBuffer>& older_msg, std::vector<MessageBuffer>& newer_msg) {
T::fromOlder(*(const typename T::MessageOlder*)older_msg[0].get(), *(typename T::MessageNewer*)newer_msg[0].get());
};
auto translation_cb_to_older = [](const std::vector<MessageBuffer>& newer_msg, std::vector<MessageBuffer>& older_msg) {
T::toOlder(*(const typename T::MessageNewer*)newer_msg[0].get(), *(typename T::MessageOlder*)older_msg[0].get());
};
_topic_translations.addTranslation({translation_cb_from_older,
{MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}});
_topic_translations.addTranslation({translation_cb_to_older,
{MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}});
}
/**
* @brief Register a translation class for a service.
*
* The translation class has the form:
*
* ```
* class MyServiceTranslation {
* public:
* using MessageOlder = px4_msgs_old::srv::VehicleCommandV0;
* using MessageNewer = px4_msgs::srv::VehicleCommand;
*
* static constexpr const char* kTopic = "fmu/vehicle_command";
*
* static void fromOlder(const MessageOlder::Request &msg_older, MessageNewer::Request &msg_newer) {
* // set msg_newer from msg_older
* }
*
* static void toOlder(const MessageNewer::Request &msg_newer, MessageOlder::Request &msg_older) {
* // set msg_older from msg_newer
* }
*
* static void fromOlder(const MessageOlder::Response &msg_older, MessageNewer::Response &msg_newer) {
* // set msg_newer from msg_older
* }
*
* static void toOlder(const MessageNewer::Response &msg_newer, MessageOlder::Response &msg_older) {
* // set msg_older from msg_newer
* }
* };
* ```
*/
template<class T>
void registerServiceDirectTranslation() {
const std::string topic_name = T::kTopic;
_service_translations.addNode(getServiceForMessageType<typename T::MessageOlder>(topic_name));
_service_translations.addNode(getServiceForMessageType<typename T::MessageNewer>(topic_name));
// Add translations
{ // Request
auto translation_cb_from_older = [](const std::vector<MessageBuffer> &older_msg,
std::vector<MessageBuffer> &newer_msg) {
T::fromOlder(*(const typename T::MessageOlder::Request *) older_msg[0].get(),
*(typename T::MessageNewer::Request *) newer_msg[0].get());
};
auto translation_cb_to_older = [](const std::vector<MessageBuffer> &newer_msg,
std::vector<MessageBuffer> &older_msg) {
T::toOlder(*(const typename T::MessageNewer::Request *) newer_msg[0].get(),
*(typename T::MessageOlder::Request *) older_msg[0].get());
};
_service_translations.addRequestTranslation({translation_cb_from_older,
{MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}});
_service_translations.addRequestTranslation({translation_cb_to_older,
{MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}});
}
{ // Response
auto translation_cb_from_older = [](const std::vector<MessageBuffer> &older_msg,
std::vector<MessageBuffer> &newer_msg) {
T::fromOlder(*(const typename T::MessageOlder::Response *) older_msg[0].get(),
*(typename T::MessageNewer::Response *) newer_msg[0].get());
};
auto translation_cb_to_older = [](const std::vector<MessageBuffer> &newer_msg,
std::vector<MessageBuffer> &older_msg) {
T::toOlder(*(const typename T::MessageNewer::Response *) newer_msg[0].get(),
*(typename T::MessageOlder::Response *) older_msg[0].get());
};
_service_translations.addResponseTranslation({translation_cb_from_older,
{MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}});
_service_translations.addResponseTranslation({translation_cb_to_older,
{MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}},
{MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}});
}
}
/**
* @brief Register a translation class with N input and M output messages.
*
* The translation class has the form:
* ```
* class MyTranslation {
* public:
* using MessagesOlder = TypesArray<ROS_MSG_OLDER_1, ROS_MSG_OLDER_2, ...>;
* static constexpr const char* kTopicsOlder[] = {
* "fmu/out/vehicle_global_position",
* "fmu/out/vehicle_local_position",
* ...
* };
*
* using MessagesNewer = TypesArray<ROS_MSG_NEWER_1, ROS_MSG_NEWER_2, ...>;
* static constexpr const char* kTopicsNewer[] = {
* "fmu/out/vehicle_global_position",
* "fmu/out/vehicle_local_position",
* ...
* };
*
* static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, ...
* MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2, ...) {
* // Set msg_newerX from msg_olderX
* }
*
* static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, ...
* MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2, ...) {
* // Set msg_olderX from msg_newerX
* }
* };
* ```
*/
template<class T>
void registerTranslation() {
const auto topics_older = getTopicsForMessageType(typename T::MessagesOlder::args(), T::kTopicsOlder);
std::vector<MessageIdentifier> topics_older_identifiers;
for (const auto& topic : topics_older) {
_topic_translations.addTopic(topic);
topics_older_identifiers.emplace_back(topic.id);
}
const auto topics_newer = getTopicsForMessageType(typename T::MessagesNewer::args(),T::kTopicsNewer);
std::vector<MessageIdentifier> topics_newer_identifiers;
for (const auto& topic : topics_newer) {
_topic_translations.addTopic(topic);
topics_newer_identifiers.emplace_back(topic.id);
}
// Translation callbacks
const auto translation_cb_from_older = [](const std::vector<MessageBuffer>& older_msgs, std::vector<MessageBuffer>& newer_msgs) {
call_translation_function(&T::fromOlder, typename T::MessagesOlder::args(), typename T::MessagesNewer::args(), older_msgs, newer_msgs);
};
const auto translation_cb_to_older = [](const std::vector<MessageBuffer>& newer_msgs, std::vector<MessageBuffer>& older_msgs) {
call_translation_function(&T::toOlder, typename T::MessagesNewer::args(), typename T::MessagesOlder::args(), newer_msgs, older_msgs);
};
{
// Older -> Newer
Translation translation;
translation.cb = translation_cb_from_older;
translation.inputs = topics_older_identifiers;
translation.outputs = topics_newer_identifiers;
_topic_translations.addTranslation(std::move(translation));
}
{
// Newer -> Older
Translation translation;
translation.cb = translation_cb_to_older;
translation.inputs = topics_newer_identifiers;
translation.outputs = topics_older_identifiers;
_topic_translations.addTranslation(std::move(translation));
}
}
const TopicTranslations& topicTranslations() const { return _topic_translations; }
const ServiceTranslations& serviceTranslations() const { return _service_translations; }
protected:
RegisteredTranslations() = default;
private:
template<typename RosMessageType>
static size_t getMaxSerializedMessageSize() {
const auto type_handle = rclcpp::get_message_type_support_handle<RosMessageType>();
const auto fastrtps_handle = rosidl_typesupport_cpp::get_message_typesupport_handle_function(&type_handle, "rosidl_typesupport_fastrtps_cpp");
if (fastrtps_handle) {
const auto *callbacks = static_cast<const message_type_support_callbacks_t *>(fastrtps_handle->data);
char bound_info;
return callbacks->max_serialized_size(bound_info);
}
return 0;
}
template<typename RosMessageType>
static Topic getTopicForMessageType(const std::string& topic_name) {
Topic ret{};
ret.id.topic_name = topic_name;
ret.id.version = RosMessageType::MESSAGE_VERSION;
auto message_buffer = std::make_shared<RosMessageType>();
ret.message_buffer = std::static_pointer_cast<void>(message_buffer);
// Subscription/Publication factory methods
const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.id.version);
ret.subscription_factory = [topic_name_versioned, message_buffer](rclcpp::Node& node,
const std::function<void()>& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr {
return std::dynamic_pointer_cast<rclcpp::SubscriptionBase>(
// Note: template instantiation of subscriptions slows down compilation considerably, see
// https://github.com/ros2/rclcpp/issues/1949
node.create_subscription<RosMessageType>(topic_name_versioned, rclcpp::QoS(1).best_effort(),
[on_topic_cb=on_topic_cb, message_buffer](typename RosMessageType::UniquePtr msg) -> void {
*message_buffer = *msg;
on_topic_cb();
}));
};
ret.publication_factory = [topic_name_versioned](rclcpp::Node& node) -> rclcpp::PublisherBase::SharedPtr {
return std::dynamic_pointer_cast<rclcpp::PublisherBase>(
node.create_publisher<RosMessageType>(topic_name_versioned, rclcpp::QoS(1).best_effort()));
};
ret.max_serialized_message_size = getMaxSerializedMessageSize<RosMessageType>();
return ret;
}
template<typename RosMessageType>
static Service getServiceForMessageType(const std::string& topic_name) {
Service ret{};
ret.id.topic_name = topic_name;
ret.id.version = RosMessageType::Request::MESSAGE_VERSION;
auto message_buffer_request = std::make_shared<typename RosMessageType::Request>();
ret.message_buffer_request = std::static_pointer_cast<void>(message_buffer_request);
auto message_buffer_response = std::make_shared<typename RosMessageType::Response>();
ret.message_buffer_response = std::static_pointer_cast<void>(message_buffer_response);
// Service/client factory methods
const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.id.version);
ret.service_factory = [topic_name_versioned, message_buffer_request](rclcpp::Node& node,
const std::function<void(std::shared_ptr<rmw_request_id_t> req_id)>& on_request_cb) -> rclcpp::ServiceBase::SharedPtr {
return std::dynamic_pointer_cast<rclcpp::ServiceBase>(
node.create_service<RosMessageType>(topic_name_versioned,
[on_request_cb=on_request_cb, message_buffer_request](
typename rclcpp::Service<RosMessageType>::SharedPtr service,
std::shared_ptr<rmw_request_id_t> req_id,
const std::shared_ptr<typename RosMessageType::Request> request
) -> void {
*message_buffer_request = *request;
on_request_cb(std::move(req_id));
}));
};
ret.client_factory = [topic_name_versioned, message_buffer_response](rclcpp::Node& node,
const std::function<void(rmw_request_id_t&)>& on_response_cb) {
auto client = node.create_client<RosMessageType>(topic_name_versioned);
client->set_on_new_response_callback([client, message_buffer_response, on_response_cb](size_t num) {
for (size_t i = 0; i < num; i++) {
rmw_request_id_t request_id{};
if (client->take_response(*message_buffer_response, request_id)) {
on_response_cb(request_id);
}
}
});
const auto send_request = [client](MessageBuffer request) {
auto result = client->async_send_request(std::static_pointer_cast<typename RosMessageType::Request>(request));
// We don't need the client to keep track of ongoing requests, so we remove it right away
// to prevent leaks
client->remove_pending_request(result.request_id);
return result.request_id;
};
return std::make_tuple(std::dynamic_pointer_cast<rclcpp::ClientBase>(client), send_request);
};
ret.publication_factory_request = [](rclcpp::Node& node, const std::string& topic_name) -> rclcpp::PublisherBase::SharedPtr {
return std::dynamic_pointer_cast<rclcpp::PublisherBase>(
node.create_publisher<typename RosMessageType::Request>(
topic_name,rclcpp::QoS(1).best_effort().avoid_ros_namespace_conventions(true)));
};
ret.publication_factory_response = [](rclcpp::Node& node, const std::string& topic_name) -> rclcpp::PublisherBase::SharedPtr {
return std::dynamic_pointer_cast<rclcpp::PublisherBase>(
node.create_publisher<typename RosMessageType::Response>(
topic_name,rclcpp::QoS(1).best_effort().avoid_ros_namespace_conventions(true)));
};
ret.max_serialized_message_size_request = getMaxSerializedMessageSize<typename RosMessageType::Request>();
ret.max_serialized_message_size_response = getMaxSerializedMessageSize<typename RosMessageType::Response>();
return ret;
}
template<typename... RosMessageTypes, size_t... Is>
static std::vector<Topic> getTopicsForMessageTypeImpl(const char* const topics[], std::integer_sequence<size_t, Is...>) {
std::vector<Topic> ret {
getTopicForMessageType<RosMessageTypes>(topics[Is])...
};
return ret;
}
template<typename... RosMessageTypes, size_t N>
static std::vector<Topic> getTopicsForMessageType(Pack<RosMessageTypes...>, const char* const (&topics)[N]) {
static_assert(N == sizeof...(RosMessageTypes), "Number of topics does not match number of message types");
return getTopicsForMessageTypeImpl<RosMessageTypes...>(topics, std::index_sequence_for<RosMessageTypes...>{});
}
TopicTranslations _topic_translations;
ServiceTranslations _service_translations;
};
template<class T>
class RegistrationHelperDirect {
public:
explicit RegistrationHelperDirect(const char* dummy) {
// There's something strange: when there is no argument passed, the
// compiler removes the static object completely. I don't know
// why but this dummy variable prevents that.
(void)dummy;
RegisteredTranslations::instance().registerDirectTranslation<T>();
}
explicit RegistrationHelperDirect(const char* dummy, bool for_service) {
(void)dummy;
RegisteredTranslations::instance().registerServiceDirectTranslation<T>();
}
RegistrationHelperDirect(RegistrationHelperDirect const&) = delete;
void operator=(RegistrationHelperDirect const&) = delete;
};
#define REGISTER_TOPIC_TRANSLATION_DIRECT(class_name) \
RegistrationHelperDirect<class_name> class_name##_registration_direct("dummy");
#define REGISTER_SERVICE_TRANSLATION_DIRECT(class_name) \
RegistrationHelperDirect<class_name> class_name##_service_registration_direct("dummy", true);
template<class T>
class TopicRegistrationHelperGeneric {
public:
explicit TopicRegistrationHelperGeneric(const char* dummy) {
(void)dummy;
RegisteredTranslations::instance().registerTranslation<T>();
}
TopicRegistrationHelperGeneric(TopicRegistrationHelperGeneric const&) = delete;
void operator=(TopicRegistrationHelperGeneric const&) = delete;
};
#define REGISTER_TOPIC_TRANSLATION(class_name) \
TopicRegistrationHelperGeneric<class_name> class_name##_registration_generic("dummy");
@@ -0,0 +1,5 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include "translations.h"
+91
View File
@@ -0,0 +1,91 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <string>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <functional>
#include <vector>
#include <memory>
#include <tuple>
#include "util.h"
#include "graph.h"
#include <rclcpp/rclcpp.hpp>
using TranslationCB = std::function<void(const std::vector<MessageBuffer>&, std::vector<MessageBuffer>&)>;
using SubscriptionFactoryCB = std::function<rclcpp::SubscriptionBase::SharedPtr(rclcpp::Node&, const std::function<void()>& on_topic_cb)>;
using PublicationFactoryCB = std::function<rclcpp::PublisherBase::SharedPtr(rclcpp::Node&)>;
using NamedPublicationFactoryCB = std::function<rclcpp::PublisherBase::SharedPtr(rclcpp::Node&, const std::string&)>;
using ServiceFactoryCB = std::function<rclcpp::ServiceBase::SharedPtr(rclcpp::Node&, const std::function<void(std::shared_ptr<rmw_request_id_t> req_id)>& on_request_cb)>;
using ClientSendCB = std::function<int64_t(MessageBuffer)>;
using ClientFactoryCB = std::function<std::tuple<rclcpp::ClientBase::SharedPtr, ClientSendCB>(rclcpp::Node&, const std::function<void(rmw_request_id_t&)>& on_response_cb)>;
struct Topic {
MessageIdentifier id;
SubscriptionFactoryCB subscription_factory;
PublicationFactoryCB publication_factory;
std::shared_ptr<void> message_buffer;
size_t max_serialized_message_size{};
};
struct Service {
MessageIdentifier id;
ServiceFactoryCB service_factory;
ClientFactoryCB client_factory;
NamedPublicationFactoryCB publication_factory_request;
NamedPublicationFactoryCB publication_factory_response;
std::shared_ptr<void> message_buffer_request;
size_t max_serialized_message_size_request{};
std::shared_ptr<void> message_buffer_response;
size_t max_serialized_message_size_response{};
};
struct Translation {
TranslationCB cb;
std::vector<MessageIdentifier> inputs;
std::vector<MessageIdentifier> outputs;
};
class TopicTranslations {
public:
TopicTranslations() = default;
void addTopic(Topic topic) { _topics.push_back(std::move(topic)); }
void addTranslation(Translation translation) { _translations.push_back(std::move(translation)); }
const std::vector<Topic>& topics() const { return _topics; }
const std::vector<Translation>& translations() const { return _translations; }
private:
std::vector<Topic> _topics;
std::vector<Translation> _translations;
};
class ServiceTranslations {
public:
ServiceTranslations() = default;
void addNode(Service node) { _nodes.push_back(std::move(node)); }
void addRequestTranslation(Translation translation) { _request_translations.push_back(std::move(translation)); }
void addResponseTranslation(Translation translation) { _response_translations.push_back(std::move(translation)); }
const std::vector<Service>& nodes() const { return _nodes; }
const std::vector<Translation>& requestTranslations() const { return _request_translations; }
const std::vector<Translation>& responseTranslations() const { return _response_translations; }
private:
std::vector<Service> _nodes;
std::vector<Translation> _request_translations;
std::vector<Translation> _response_translations;
};
+51
View File
@@ -0,0 +1,51 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once
#include <string>
#include <cstdint>
using MessageVersionType = uint32_t;
static inline std::string getVersionedTopicName(const std::string& topic_name, MessageVersionType version) {
// version == 0 can be used to transition from non-versioned topics to versioned ones
if (version == 0) {
return topic_name;
}
return topic_name + "_v" + std::to_string(version);
}
static inline std::pair<std::string, MessageVersionType> getNonVersionedTopicName(const std::string& topic_name) {
// topic name has the form <name>_v<version>, or just <name> (with version=0)
auto pos = topic_name.find_last_of("_v");
// Ensure there's at least one more char after the found string
if (pos == std::string::npos || pos + 2 > topic_name.length()) {
return std::make_pair(topic_name, 0);
}
std::string non_versioned_topic_name = topic_name.substr(0, pos - 1);
std::string version = topic_name.substr(pos + 1);
// Ensure only digits are in the version string
for (char c : version) {
if (!std::isdigit(c)) {
return std::make_pair(topic_name, 0);
}
}
return std::make_pair(non_versioned_topic_name, std::stol(version));
}
/**
* Get the full topic name, including namespace from a topic name.
* namespace_name should be set to Node::get_effective_namespace()
*/
static inline std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) {
std::string full_topic_name = topic_name;
if (!full_topic_name.empty() && full_topic_name[0] != '/') {
if (namespace_name.empty() || namespace_name.back() != '/') {
full_topic_name = '/' + full_topic_name;
}
full_topic_name = namespace_name + full_topic_name;
}
return full_topic_name;
}