diff --git a/resources/external_control.urscript b/resources/external_control.urscript index 6e29f034f..e956fa9aa 100644 --- a/resources/external_control.urscript +++ b/resources/external_control.urscript @@ -1003,7 +1003,8 @@ while control_mode > MODE_STOPPED: end else: textmsg("Socket timed out waiting for command on reverse_socket. The script will exit now.") - control_mode = MODE_STOPPED + stopj(STOPJ_ACCELERATION) + halt end exit_critical end diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 593f50f13..0aff4cdb4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -82,6 +82,22 @@ if (INTEGRATION_TESTS) TEST_SUFFIX _headless ) + # ExternalControlProgram tests + add_executable(external_control_program_tests_urcap test_external_control_program.cpp) + target_link_libraries(external_control_program_tests_urcap PRIVATE ur_client_library::urcl GTest::gtest_main) + gtest_add_tests(TARGET external_control_program_tests_urcap + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + EXTRA_ARGS --headless false + TEST_SUFFIX _urcap + ) + add_executable(external_control_program_tests_headless test_external_control_program.cpp) + target_link_libraries(external_control_program_tests_headless PRIVATE ur_client_library::urcl GTest::gtest_main) + gtest_add_tests(TARGET external_control_program_tests_headless + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + EXTRA_ARGS --headless true + TEST_SUFFIX _headless + ) + # InstructionExecutor tests add_executable(instruction_executor_test_urcap test_instruction_executor.cpp) target_link_libraries(instruction_executor_test_urcap PRIVATE ur_client_library::urcl GTest::gtest_main) diff --git a/tests/test_external_control_program.cpp b/tests/test_external_control_program.cpp new file mode 100644 index 000000000..5c45f60ba --- /dev/null +++ b/tests/test_external_control_program.cpp @@ -0,0 +1,140 @@ +// -- BEGIN LICENSE BLOCK ---------------------------------------------- +// Copyright 2026 Universal Robots A/S +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// * Neither the name of the {copyright_holder} nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// -- END LICENSE BLOCK ------------------------------------------------ + +#include + +#include "test_utils.h" +#include "ur_client_library/example_robot_wrapper.h" + +using namespace urcl; +const std::string SCRIPT_FILE = "../resources/external_control.urscript"; +const std::string OUTPUT_RECIPE = "resources/rtde_output_recipe.txt"; +const std::string INPUT_RECIPE = "resources/rtde_input_recipe.txt"; +std::string g_ROBOT_IP = "192.168.56.101"; +bool g_HEADLESS = true; + +std::unique_ptr g_my_robot; + +class ExternalControlProgramTest : public ::testing::Test +{ +protected: + static void SetUpTestSuite() + { + if (!(robotVersionLessThan(g_ROBOT_IP, "10.0.0") || g_HEADLESS)) + { + GTEST_SKIP_("Running URCap tests for PolyScope X is currently not supported."); + } + } + void SetUp() override + { + std::string modified_script_path = extendScript(SCRIPT_FILE); + + g_my_robot = std::make_unique(g_ROBOT_IP, OUTPUT_RECIPE, INPUT_RECIPE, g_HEADLESS, + "external_control.urp", modified_script_path); + if (!g_my_robot->isHealthy()) + { + ASSERT_TRUE(g_my_robot->resendRobotProgram()); + ASSERT_TRUE(g_my_robot->waitForProgramRunning(500)); + } + server_.reset(new TestableTcpServer(60005)); + server_->start(); + } + + void TearDown() override + { + server_.reset(); + } + + std::string extendScript(const std::string& script_path) + { + char modified_script_path[] = "urscript.XXXXXX"; +#ifdef _WIN32 +# define mkstemp _mktemp_s +#endif + std::ignore = mkstemp(modified_script_path); + + std::ofstream ofs(modified_script_path); + if (ofs.bad()) + { + std::cout << "Failed to create temporary files" << std::endl; + throw std::runtime_error("Failed to create temporary files"); + } + std::ifstream in_file(script_path); + std::string prog((std::istreambuf_iterator(in_file)), (std::istreambuf_iterator())); + prog += "\nsocket_open(\"{{SERVER_IP_REPLACE}}\", 60005, \"test_socket\")\n"; + prog += "\nsleep(0.6)\n"; + prog += "\ntextmsg(\"sleeping done.\")\n"; + ofs << prog; + ofs.close(); + + return modified_script_path; + } + + std::unique_ptr server_; +}; + +TEST_F(ExternalControlProgramTest, program_halts_on_timeout) +{ + vector6d_t zeros = { 0, 0, 0, 0, 0, 0 }; + g_my_robot->getUrDriver()->writeJointCommand(zeros, comm::ControlMode::MODE_IDLE, RobotReceiveTimeout::millisec(200)); + EXPECT_FALSE(server_->waitForConnectionCallback(1000)); +} + +TEST_F(ExternalControlProgramTest, stop_control_does_not_halt_program) +{ + vector6d_t zeros = { 0, 0, 0, 0, 0, 0 }; + g_my_robot->getUrDriver()->writeJointCommand(zeros, comm::ControlMode::MODE_IDLE, RobotReceiveTimeout::off()); + + // Make sure that we can stop the robot control, when robot receive timeout has been set off + g_my_robot->getUrDriver()->stopControl(); + EXPECT_TRUE(server_->waitForConnectionCallback(1000)); +} + +int main(int argc, char* argv[]) +{ + ::testing::InitGoogleTest(&argc, argv); + + for (int i = 0; i < argc; i++) + { + if (std::string(argv[i]) == "--robot_ip" && i + 1 < argc) + { + g_ROBOT_IP = argv[i + 1]; + ++i; + } + if (std::string(argv[i]) == "--headless" && i + 1 < argc) + { + std::string headless = argv[i + 1]; + g_HEADLESS = headless == "true" || headless == "1" || headless == "True" || headless == "TRUE"; + ++i; + } + } + + return RUN_ALL_TESTS(); +} diff --git a/tests/test_pipeline.cpp b/tests/test_pipeline.cpp index 8f85c8f85..c7f692234 100644 --- a/tests/test_pipeline.cpp +++ b/tests/test_pipeline.cpp @@ -31,6 +31,8 @@ #include #include +#include "test_utils.h" + #include #include #include @@ -45,8 +47,7 @@ class PipelineTest : public ::testing::Test protected: void SetUp() { - server_.reset(new comm::TCPServer(60002)); - server_->setConnectCallback(std::bind(&PipelineTest::connectionCallback, this, std::placeholders::_1)); + server_.reset(new TestableTcpServer(60002)); server_->start(); // Setup pipeline @@ -68,28 +69,7 @@ class PipelineTest : public ::testing::Test server_.reset(); } - void connectionCallback(const socket_t filedescriptor) - { - std::lock_guard lk(connect_mutex_); - client_fd_ = filedescriptor; - connect_cv_.notify_one(); - connection_callback_ = true; - } - - bool waitForConnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(connect_mutex_); - if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - connection_callback_ == true) - { - connection_callback_ = false; - return true; - } - return false; - } - - std::unique_ptr server_; - socket_t client_fd_; + std::unique_ptr server_; std::unique_ptr> stream_; std::unique_ptr parser_; @@ -138,8 +118,6 @@ class PipelineTest : public ::testing::Test private: std::condition_variable connect_cv_; std::mutex connect_mutex_; - - bool connection_callback_ = false; }; TEST_F(PipelineTest, get_product_from_stopped_pipeline) @@ -151,13 +129,13 @@ TEST_F(PipelineTest, get_product_from_stopped_pipeline) TEST_F(PipelineTest, get_product_from_running_pipeline) { - waitForConnectionCallback(); + server_->waitForConnectionCallback(); pipeline_->run(); // RTDE package with timestamp uint8_t data_package[] = { 0x00, 0x0c, 0x55, 0x01, 0x40, 0xbb, 0xbf, 0xdb, 0xa5, 0xe3, 0x53, 0xf7 }; size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); std::unique_ptr urpackage; std::chrono::milliseconds timeout{ 500 }; @@ -178,13 +156,13 @@ TEST_F(PipelineTest, get_product_from_running_pipeline) TEST_F(PipelineTest, stop_pipeline) { - waitForConnectionCallback(); + server_->waitForConnectionCallback(); pipeline_->run(); // RTDE package with timestamp uint8_t data_package[] = { 0x00, 0x0c, 0x55, 0x01, 0x40, 0xbb, 0xbf, 0xdb, 0xa5, 0xe3, 0x53, 0xf7 }; size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); std::unique_ptr urpackage; std::chrono::milliseconds timeout{ 500 }; @@ -206,13 +184,13 @@ TEST_F(PipelineTest, consumer_pipeline) pipeline_.reset( new comm::Pipeline(*producer_.get(), &consumer, "RTDE_PIPELINE", notifier_)); pipeline_->init(); - waitForConnectionCallback(); + server_->waitForConnectionCallback(); pipeline_->run(); // RTDE package with timestamp uint8_t data_package[] = { 0x00, 0x0c, 0x55, 0x01, 0x40, 0xbb, 0xbf, 0xdb, 0xa5, 0xe3, 0x53, 0xf7 }; size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); // Wait for data to be consumed int max_retries = 3; @@ -223,7 +201,7 @@ TEST_F(PipelineTest, consumer_pipeline) { break; } - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); count++; } EXPECT_LT(count, max_retries); diff --git a/tests/test_producer.cpp b/tests/test_producer.cpp index 55741bdcc..32069be32 100644 --- a/tests/test_producer.cpp +++ b/tests/test_producer.cpp @@ -31,6 +31,7 @@ #include #include #include +#include "test_utils.h" #include #include @@ -44,8 +45,7 @@ class ProducerTest : public ::testing::Test protected: void SetUp() { - server_.reset(new comm::TCPServer(60002)); - server_->setConnectCallback(std::bind(&ProducerTest::connectionCallback, this, std::placeholders::_1)); + server_.reset(new TestableTcpServer(60002)); server_->start(); } @@ -55,34 +55,7 @@ class ProducerTest : public ::testing::Test server_.reset(); } - void connectionCallback(const socket_t filedescriptor) - { - std::lock_guard lk(connect_mutex_); - client_fd_ = filedescriptor; - connect_cv_.notify_one(); - connection_callback_ = true; - } - - bool waitForConnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(connect_mutex_); - if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - connection_callback_ == true) - { - connection_callback_ = false; - return true; - } - return false; - } - - std::unique_ptr server_; - socket_t client_fd_; - -private: - std::condition_variable connect_cv_; - std::mutex connect_mutex_; - - bool connection_callback_ = false; + std::unique_ptr server_; }; TEST_F(ProducerTest, get_data_package) @@ -94,13 +67,13 @@ TEST_F(ProducerTest, get_data_package) comm::URProducer producer(stream, parser); producer.setupProducer(); - waitForConnectionCallback(); + server_->waitForConnectionCallback(); producer.startProducer(); // RTDE package with timestamp uint8_t data_package[] = { 0x00, 0x0c, 0x55, 0x01, 0x40, 0xbb, 0xbf, 0xdb, 0xa5, 0xe3, 0x53, 0xf7 }; size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); std::vector> products; EXPECT_EQ(producer.tryGet(products), true); diff --git a/tests/test_stream.cpp b/tests/test_stream.cpp index 1cc735005..75efcbf16 100644 --- a/tests/test_stream.cpp +++ b/tests/test_stream.cpp @@ -31,10 +31,12 @@ #include #include #include +#include #include #include #include +#include "test_utils.h" #include "ur_client_library/primary/primary_package.h" using namespace urcl; @@ -44,10 +46,7 @@ class StreamTest : public ::testing::Test protected: void SetUp() { - server_.reset(new comm::TCPServer(60003)); - server_->setConnectCallback(std::bind(&StreamTest::connectionCallback, this, std::placeholders::_1)); - server_->setMessageCallback(std::bind(&StreamTest::messageCallback, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3)); + server_.reset(new TestableTcpServer(60003)); server_->start(); } @@ -57,62 +56,7 @@ class StreamTest : public ::testing::Test server_.reset(); } - // callback functions for the tcp server - void messageCallback([[maybe_unused]] const socket_t filedescriptor, char* buffer, size_t nbytesrecv) - { - std::lock_guard lk(message_mutex_); - read_ = nbytesrecv; - received_message_ = std::string(buffer); - message_cv_.notify_one(); - message_callback_ = true; - } - - void connectionCallback(const socket_t filedescriptor) - { - std::lock_guard lk(connect_mutex_); - client_fd_ = filedescriptor; - connect_cv_.notify_one(); - connection_callback_ = true; - } - - bool waitForMessageCallback(int milliseconds = 100) - { - std::unique_lock lk(message_mutex_); - if (message_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - message_callback_ == true) - { - message_callback_ = false; - return true; - } - return false; - } - - bool waitForConnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(connect_mutex_); - if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - connection_callback_ == true) - { - connection_callback_ = false; - return true; - } - return false; - } - - std::unique_ptr server_; - socket_t client_fd_; - std::string received_message_; - size_t read_; - -private: - std::condition_variable message_cv_; - std::mutex message_mutex_; - - std::condition_variable connect_cv_; - std::mutex connect_mutex_; - - bool connection_callback_ = false; - bool message_callback_ = false; + std::unique_ptr server_; }; TEST_F(StreamTest, closed_stream) @@ -120,7 +64,7 @@ TEST_F(StreamTest, closed_stream) comm::URStream stream("127.0.0.1", 60003); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); EXPECT_FALSE(stream.closed()); stream.close(); @@ -137,7 +81,7 @@ TEST_F(StreamTest, connect_stream) EXPECT_EQ(stream.getState(), expected_state); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); expected_state = comm::SocketState::Connected; EXPECT_EQ(stream.getState(), expected_state); @@ -151,10 +95,10 @@ TEST_F(StreamTest, read_buffer_to_small) comm::URStream stream("127.0.0.1", 60003); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); uint8_t buf[10]; size_t read = 0; @@ -172,10 +116,10 @@ TEST_F(StreamTest, read_rtde_data_package) comm::URStream stream("127.0.0.1", 60003); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); uint8_t buf[4096]; size_t read = 0; @@ -285,10 +229,10 @@ TEST_F(StreamTest, read_primary_data_package) comm::URStream stream("127.0.0.1", 60003); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); size_t written; - server_->write(client_fd_, data_package, sizeof(data_package), written); + server_->write(data_package, sizeof(data_package), written); uint8_t buf[4096]; size_t read = 0; @@ -306,7 +250,7 @@ TEST_F(StreamTest, write_data_package) comm::URStream stream("127.0.0.1", 60003); stream.connect(); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); std::string send_message = "test message"; const uint8_t* data = reinterpret_cast(send_message.c_str()); @@ -314,11 +258,14 @@ TEST_F(StreamTest, write_data_package) size_t written; stream.write(data, len, written); - EXPECT_TRUE(waitForMessageCallback()); + EXPECT_TRUE(server_->waitForMessageCallback()); + + size_t bytes_read; + std::string received_message = server_->getReceivedMessage(bytes_read); // Test that the message and the size of the message are equal - EXPECT_EQ(written, read_); - EXPECT_EQ(send_message, received_message_); + EXPECT_EQ(written, bytes_read); + EXPECT_EQ(send_message, received_message); } TEST_F(StreamTest, connect_non_connected_robot) diff --git a/tests/test_tcp_server.cpp b/tests/test_tcp_server.cpp index d85994216..08c8c4fb8 100644 --- a/tests/test_tcp_server.cpp +++ b/tests/test_tcp_server.cpp @@ -36,6 +36,7 @@ #include #include #include +#include "test_utils.h" #include #include @@ -82,98 +83,12 @@ class TCPServerTest : public ::testing::Test } }; - // callback functions - void connectionCallback(const socket_t filedescriptor) - { - std::lock_guard lk(connect_mutex_); - client_fd_ = filedescriptor; - connect_cv_.notify_one(); - connection_callback_ = true; - } - - void disconnectionCallback([[maybe_unused]] const socket_t filedescriptor) - { - std::lock_guard lk(disconnect_mutex_); - client_fd_ = INVALID_SOCKET; - disconnect_cv_.notify_one(); - disconnection_callback_ = true; - } - - void messageCallback([[maybe_unused]] const socket_t filedescriptor, char* buffer) - { - std::lock_guard lk(message_mutex_); - message_ = std::string(buffer); - message_cv_.notify_one(); - message_callback_ = true; - } - - bool waitForConnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(connect_mutex_); - if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - connection_callback_ == true) - { - connection_callback_ = false; - return true; - } - else - { - return false; - } - } - - bool waitForDisconnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(disconnect_mutex_); - if (disconnect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - disconnection_callback_ == true) - { - disconnection_callback_ = false; - return true; - } - else - { - return false; - } - } - - bool waitForMessageCallback(int milliseconds = 100) - { - std::unique_lock lk(message_mutex_); - if (message_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - message_callback_ == true) - { - message_callback_ = false; - return true; - } - else - { - return false; - } - } - int port_ = 50001; - std::string message_ = ""; - socket_t client_fd_ = INVALID_SOCKET; - -private: - std::condition_variable connect_cv_; - std::mutex connect_mutex_; - - std::condition_variable disconnect_cv_; - std::mutex disconnect_mutex_; - - std::condition_variable message_cv_; - std::mutex message_mutex_; - - bool connection_callback_ = false; - bool disconnection_callback_ = false; - bool message_callback_ = false; }; TEST_F(TCPServerTest, socket_creation) { - comm::TCPServer server(port_); + TestableTcpServer server(port_, false); // do not register callbacks // Shouldn't be able to create antoher server on same port EXPECT_THROW(comm::TCPServer server2(port_, 1, std::chrono::milliseconds(1)), std::system_error); @@ -186,43 +101,31 @@ TEST_F(TCPServerTest, socket_creation) // We should also be able to send message and disconnect. We wait to be absolutely sure no exception is thrown EXPECT_NO_THROW(client.send("message\n")); - EXPECT_NO_THROW(waitForMessageCallback()); + EXPECT_NO_THROW(server.waitForMessageCallback()); EXPECT_NO_THROW(client.close()); - EXPECT_NO_THROW(waitForDisconnectionCallback()); + EXPECT_NO_THROW(server.waitForDisconnectionCallback()); } TEST_F(TCPServerTest, callback_functions) { - comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_callback_functions_Test::messageCallback, this, - std::placeholders::_1, std::placeholders::_2)); - server.setConnectCallback( - std::bind(&TCPServerTest_callback_functions_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_callback_functions_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); // Check that the appropriate callback functions are called Client client(port_); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server.waitForConnectionCallback()); client.send("message\n"); - EXPECT_TRUE(waitForMessageCallback()); + EXPECT_TRUE(server.waitForMessageCallback()); client.close(); - EXPECT_TRUE(waitForDisconnectionCallback()); + EXPECT_TRUE(server.waitForDisconnectionCallback()); } TEST_F(TCPServerTest, many_clients_allowed) { - comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_many_clients_allowed_Test::messageCallback, this, - std::placeholders::_1, std::placeholders::_2)); - server.setConnectCallback( - std::bind(&TCPServerTest_many_clients_allowed_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_many_clients_allowed_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); #ifdef _WIN32 @@ -239,67 +142,49 @@ TEST_F(TCPServerTest, many_clients_allowed) for (unsigned int i = 0; i < num_clients; ++i) { clients.push_back(std::make_unique(port_)); - ASSERT_TRUE(waitForConnectionCallback()); + ASSERT_TRUE(server.waitForConnectionCallback()); } } TEST_F(TCPServerTest, max_clients_allowed) { - comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_max_clients_allowed_Test::messageCallback, this, - std::placeholders::_1, std::placeholders::_2)); - server.setConnectCallback( - std::bind(&TCPServerTest_max_clients_allowed_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_max_clients_allowed_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); server.setMaxClientsAllowed(1); // Test that only one client can connect Client client1(port_); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server.waitForConnectionCallback()); Client client2(port_); - EXPECT_FALSE(waitForConnectionCallback()); + EXPECT_FALSE(server.waitForConnectionCallback()); } TEST_F(TCPServerTest, message_transmission) { - comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_message_transmission_Test::messageCallback, this, - std::placeholders::_1, std::placeholders::_2)); - server.setConnectCallback( - std::bind(&TCPServerTest_message_transmission_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_message_transmission_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); Client client(port_); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server.waitForConnectionCallback()); // Test that messages are transmitted corectly between client and server std::string message = "test message\n"; client.send(message); - EXPECT_TRUE(waitForMessageCallback()); - EXPECT_EQ(message, message_); + EXPECT_TRUE(server.waitForMessageCallback()); + EXPECT_EQ(message, server.getReceivedMessage()); size_t len = message.size(); const uint8_t* data = reinterpret_cast(message.c_str()); size_t written; - ASSERT_TRUE(server.write(client_fd_, data, len, written)); + ASSERT_TRUE(server.write(data, len, written)); EXPECT_EQ(client.recv(), message); } TEST_F(TCPServerTest, client_connections) { - comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_client_connections_Test::messageCallback, this, - std::placeholders::_1, std::placeholders::_2)); - server.setConnectCallback( - std::bind(&TCPServerTest_client_connections_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_client_connections_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); std::string message = "text message\n"; @@ -309,36 +194,37 @@ TEST_F(TCPServerTest, client_connections) // Test that we can connect multiple clients Client client1(port_); - EXPECT_TRUE(waitForConnectionCallback()); - socket_t client1_fd = client_fd_; + EXPECT_TRUE(server.waitForConnectionCallback()); Client client2(port_); - EXPECT_TRUE(waitForConnectionCallback()); - socket_t client2_fd = client_fd_; + EXPECT_TRUE(server.waitForConnectionCallback()); Client client3(port_); - EXPECT_TRUE(waitForConnectionCallback()); - socket_t client3_fd = client_fd_; + EXPECT_TRUE(server.waitForConnectionCallback()); + + auto client_fds = server.getClientFDs(); // Test that the correct clients are disconnected on the server side. client1.close(); - EXPECT_TRUE(waitForDisconnectionCallback()); + EXPECT_TRUE(server.waitForDisconnectionCallback()); - EXPECT_FALSE(server.write(client1_fd, data, len, written)); - EXPECT_TRUE(server.write(client2_fd, data, len, written)); - EXPECT_TRUE(server.write(client3_fd, data, len, written)); + auto tcp_server = dynamic_cast(&server); + + EXPECT_FALSE(tcp_server->write(client_fds[0], data, len, written)); + EXPECT_TRUE(tcp_server->write(client_fds[1], data, len, written)); + EXPECT_TRUE(tcp_server->write(client_fds[2], data, len, written)); client2.close(); - EXPECT_TRUE(waitForDisconnectionCallback()); - EXPECT_FALSE(server.write(client1_fd, data, len, written)); - EXPECT_FALSE(server.write(client2_fd, data, len, written)); - EXPECT_TRUE(server.write(client3_fd, data, len, written)); + EXPECT_TRUE(server.waitForDisconnectionCallback()); + EXPECT_FALSE(tcp_server->write(client_fds[0], data, len, written)); + EXPECT_FALSE(tcp_server->write(client_fds[1], data, len, written)); + EXPECT_TRUE(tcp_server->write(client_fds[2], data, len, written)); client3.close(); - EXPECT_TRUE(waitForDisconnectionCallback()); - EXPECT_FALSE(server.write(client1_fd, data, len, written)); - EXPECT_FALSE(server.write(client2_fd, data, len, written)); - EXPECT_FALSE(server.write(client3_fd, data, len, written)); + EXPECT_TRUE(server.waitForDisconnectionCallback()); + EXPECT_FALSE(tcp_server->write(client_fds[0], data, len, written)); + EXPECT_FALSE(tcp_server->write(client_fds[1], data, len, written)); + EXPECT_FALSE(tcp_server->write(client_fds[2], data, len, written)); } TEST_F(TCPServerTest, check_address_already_in_use) { @@ -388,15 +274,11 @@ TEST_F(TCPServerTest, check_shutting_down_server_while_listening) TEST_F(TCPServerTest, double_shutdown) { - comm::TCPServer server(port_); - server.setConnectCallback( - std::bind(&TCPServerTest_double_shutdown_Test::connectionCallback, this, std::placeholders::_1)); - server.setDisconnectCallback( - std::bind(&TCPServerTest_double_shutdown_Test::disconnectionCallback, this, std::placeholders::_1)); + TestableTcpServer server(port_); server.start(); Client client(port_); - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server.waitForConnectionCallback()); EXPECT_NO_THROW(server.shutdown()); EXPECT_NO_THROW(server.shutdown()); @@ -404,14 +286,11 @@ TEST_F(TCPServerTest, double_shutdown) TEST_F(TCPServerTest, concurrent_writes_same_client) { - comm::TCPServer server(0); - server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); - server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + TestableTcpServer server(0); server.start(); Client client(server.getPort()); - ASSERT_TRUE(waitForConnectionCallback()); - const socket_t fd = client_fd_; + ASSERT_TRUE(server.waitForConnectionCallback()); const std::string message = "test data\n"; const auto* data = reinterpret_cast(message.c_str()); @@ -424,11 +303,11 @@ TEST_F(TCPServerTest, concurrent_writes_same_client) for (int i = 0; i < num_threads; ++i) { - writers.emplace_back([&server, fd, data, len, &success_count]() { + writers.emplace_back([&server, data, len, &success_count]() { for (int j = 0; j < writes_per_thread; ++j) { size_t written; - if (server.write(fd, data, len, written)) + if (server.write(data, len, written)) { ++success_count; } @@ -446,14 +325,11 @@ TEST_F(TCPServerTest, concurrent_writes_same_client) TEST_F(TCPServerTest, write_during_client_disconnect) { - comm::TCPServer server(0); - server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); - server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + TestableTcpServer server(0); server.start(); Client client(server.getPort()); - ASSERT_TRUE(waitForConnectionCallback()); - const socket_t fd = client_fd_; + ASSERT_TRUE(server.waitForConnectionCallback()); const std::string message = "test data\n"; const auto* data = reinterpret_cast(message.c_str()); @@ -461,17 +337,17 @@ TEST_F(TCPServerTest, write_during_client_disconnect) std::atomic stop{ false }; - std::thread writer([&server, fd, data, len, &stop]() { + std::thread writer([&server, data, len, &stop]() { while (!stop.load()) { size_t written; - server.write(fd, data, len, written); + server.write(data, len, written); } }); std::this_thread::sleep_for(std::chrono::milliseconds(10)); client.close(); - ASSERT_TRUE(waitForDisconnectionCallback()); + ASSERT_TRUE(server.waitForDisconnectionCallback()); std::this_thread::sleep_for(std::chrono::milliseconds(10)); stop.store(true); @@ -595,14 +471,11 @@ TEST_F(TCPServerTest, concurrent_writes_multiple_clients) TEST_F(TCPServerTest, shutdown_during_active_writes) { - comm::TCPServer server(0); - server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); - server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + TestableTcpServer server(0); server.start(); Client client(server.getPort()); - ASSERT_TRUE(waitForConnectionCallback()); - const socket_t fd = client_fd_; + ASSERT_TRUE(server.waitForConnectionCallback()); const std::string message = "test data\n"; const auto* data = reinterpret_cast(message.c_str()); @@ -610,11 +483,11 @@ TEST_F(TCPServerTest, shutdown_during_active_writes) std::atomic stop{ false }; - std::thread writer([&server, fd, data, len, &stop]() { + std::thread writer([&server, data, len, &stop]() { while (!stop.load()) { size_t written; - server.write(fd, data, len, written); + server.write(data, len, written); } }); diff --git a/tests/test_tcp_socket.cpp b/tests/test_tcp_socket.cpp index f3a6fe9bf..d64f12ea5 100644 --- a/tests/test_tcp_socket.cpp +++ b/tests/test_tcp_socket.cpp @@ -32,6 +32,7 @@ #include #include #include +#include "test_utils.h" // This file adds a test for a deprecated function. To avoid a compiler warning in CI (where we want // to treat warnings as errors) we suppress the warning inside this file. @@ -51,10 +52,7 @@ class TCPSocketTest : public ::testing::Test protected: void SetUp() { - server_.reset(new comm::TCPServer(60001)); - server_->setConnectCallback(std::bind(&TCPSocketTest::connectionCallback, this, std::placeholders::_1)); - server_->setMessageCallback( - std::bind(&TCPSocketTest::messageCallback, this, std::placeholders::_1, std::placeholders::_2)); + server_.reset(new TestableTcpServer(60001)); server_->start(); client_.reset(new Client(60001)); @@ -66,47 +64,6 @@ class TCPSocketTest : public ::testing::Test client_.reset(); } - // callback functions for the tcp server - void messageCallback([[maybe_unused]] const socket_t filedescriptor, char* buffer) - { - std::lock_guard lk(message_mutex_); - received_message_ = std::string(buffer); - message_cv_.notify_one(); - message_callback_ = true; - } - - void connectionCallback(const socket_t filedescriptor) - { - std::lock_guard lk(connect_mutex_); - client_fd_ = filedescriptor; - connect_cv_.notify_one(); - connection_callback_ = true; - } - - bool waitForMessageCallback(int milliseconds = 100) - { - std::unique_lock lk(message_mutex_); - if (message_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - message_callback_ == true) - { - message_callback_ = false; - return true; - } - return false; - } - - bool waitForConnectionCallback(int milliseconds = 100) - { - std::unique_lock lk(connect_mutex_); - if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds)) == std::cv_status::no_timeout || - connection_callback_ == true) - { - connection_callback_ = false; - return true; - } - return false; - } - class Client : public comm::TCPSocket { public: @@ -162,21 +119,8 @@ class TCPSocketTest : public ::testing::Test } }; - std::string received_message_; - socket_t client_fd_; - - std::unique_ptr server_; + std::unique_ptr server_; std::unique_ptr client_; - -private: - std::condition_variable message_cv_; - std::mutex message_mutex_; - - std::condition_variable connect_cv_; - std::mutex connect_mutex_; - - bool connection_callback_ = false; - bool message_callback_ = false; }; TEST_F(TCPSocketTest, socket_state) @@ -218,7 +162,7 @@ TEST_F(TCPSocketTest, setup_client_before_server) EXPECT_EQ(toUnderlying(expected_state), toUnderlying(actual_state)); - server_.reset(new comm::TCPServer(60001)); + server_.reset(new TestableTcpServer(60001)); server_->start(); // Test that client goes into connected state after the server has been started @@ -272,8 +216,8 @@ TEST_F(TCPSocketTest, write_on_connected_socket) size_t written; client_->write(data, len, written); - EXPECT_TRUE(waitForMessageCallback()); - EXPECT_EQ(message, received_message_); + EXPECT_TRUE(server_->waitForMessageCallback()); + EXPECT_EQ(message, server_->getReceivedMessage()); } TEST_F(TCPSocketTest, read_on_connected_socket) @@ -281,13 +225,13 @@ TEST_F(TCPSocketTest, read_on_connected_socket) client_->setup(); // Make sure the client has connected to the server, before writing to the client - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); std::string send_message = "test message"; size_t len = send_message.size(); const uint8_t* data = reinterpret_cast(send_message.c_str()); size_t written; - server_->write(client_fd_, data, len, written); + server_->write(data, len, written); std::stringstream ss; char characters; @@ -367,13 +311,13 @@ TEST_F(TCPSocketTest, test_read_on_socket_abruptly_closed) client_->setup(); // Make sure the client has connected to the server, before writing to the client - EXPECT_TRUE(waitForConnectionCallback()); + EXPECT_TRUE(server_->waitForConnectionCallback()); std::string send_message = "test message"; size_t len = send_message.size(); const uint8_t* data = reinterpret_cast(send_message.c_str()); size_t written; - server_->write(client_fd_, data, len, written); + server_->write(data, len, written); // Simulate socket failure ur_close(client_->getSocketFD()); diff --git a/tests/test_utils.h b/tests/test_utils.h index 7db5ca04d..31fc11022 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -31,6 +31,7 @@ #pragma once #include +#include "ur_client_library/comm/tcp_server.h" bool robotVersionLessThan(const std::string& robot_ip, const std::string& robot_version) { @@ -40,3 +41,141 @@ bool robotVersionLessThan(const std::string& robot_ip, const std::string& robot_ auto version_information = primary_client.getRobotVersion(); return *version_information < urcl::VersionInformation::fromString(robot_version); } + +class TestableTcpServer : public urcl::comm::TCPServer +{ +public: + TestableTcpServer(const int port, const bool register_callbacks = true) : TCPServer(port) + { + if (register_callbacks) + { + this->setConnectCallback(std::bind(&TestableTcpServer::connectionCallback, this, std::placeholders::_1)); + this->setMessageCallback(std::bind(&TestableTcpServer::messageCallback, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); + this->setDisconnectCallback(std::bind(&TestableTcpServer::disconnectionCallback, this, std::placeholders::_1)); + } + } + + ~TestableTcpServer() + { + // unregister callbacks to avoid any callback being triggered after the server is destroyed, + // which would cause the tests to fail due to accessing already destroyed objects. + setConnectCallback([](const socket_t) {}); + setMessageCallback([](const socket_t, char*, int) {}); + setDisconnectCallback([](const socket_t) {}); + } + + void connectionCallback(const socket_t filedescriptor) + { + std::lock_guard lk(connect_mutex_); + client_fds_.push_back(filedescriptor); + connect_cv_.notify_one(); + connection_callback_ = true; + } + + void messageCallback([[maybe_unused]] const socket_t filedescriptor, char* buffer, int nbytesrecv) + { + std::lock_guard lk(message_mutex_); + received_message_ = std::string(buffer); + read_ = nbytesrecv; + message_cv_.notify_one(); + message_callback_ = true; + } + + void disconnectionCallback(const socket_t filedescriptor) + { + std::lock_guard lk(connect_mutex_); + for (size_t i = 0; i < client_fds_.size(); ++i) + { + if (client_fds_[i] == filedescriptor) + { + client_fds_.erase(client_fds_.begin() + i); + break; + } + } + disconnect_cv_.notify_one(); + disconnection_callback_ = true; + } + + bool waitForConnectionCallback(int milliseconds = 100) + { + std::unique_lock lk(connect_mutex_); + if (connect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds), + [this]() { return connection_callback_ == true; })) + { + connection_callback_ = false; + return true; + } + return false; + } + + bool waitForMessageCallback(int milliseconds = 100) + { + std::unique_lock lk(message_mutex_); + if (message_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds), + [this]() { return message_callback_ == true; })) + { + message_callback_ = false; + return true; + } + return false; + } + + bool waitForDisconnectionCallback(int milliseconds = 100) + { + std::unique_lock lk(connect_mutex_); + if (disconnect_cv_.wait_for(lk, std::chrono::milliseconds(milliseconds), + [this]() { return disconnection_callback_ == true; })) + { + disconnection_callback_ = false; + return true; + } + else + { + return false; + } + } + + bool write(const uint8_t* buf, const size_t buf_len, size_t& written, const size_t client_index = 0) + { + std::unique_lock lk(connect_mutex_); + if (client_fds_.empty() || client_index >= client_fds_.size()) + { + return false; + } + return TCPServer::write(client_fds_[client_index], buf, buf_len, written); + } + + std::string getReceivedMessage() + { + size_t bytes_read; + return getReceivedMessage(bytes_read); + } + + std::string getReceivedMessage(size_t& bytes_read) + { + std::lock_guard lk(message_mutex_); + bytes_read = read_; + return received_message_; + } + + std::vector getClientFDs() + { + std::lock_guard lk(connect_mutex_); + return client_fds_; + } + +private: + std::vector client_fds_; + std::condition_variable connect_cv_; + std::condition_variable message_cv_; + std::condition_variable disconnect_cv_; + std::mutex connect_mutex_; + std::mutex message_mutex_; + std::atomic connection_callback_ = false; + std::atomic message_callback_ = false; + std::atomic disconnection_callback_ = false; + + std::string received_message_; + size_t read_ = 0; +};