diff --git a/test/test_rcon.c b/test/test_rcon.c index 307d023..e4f934c 100644 --- a/test/test_rcon.c +++ b/test/test_rcon.c @@ -16,6 +16,30 @@ static void assert_packet_equals(const rcon_packet_t *packet, (int32_t)(sizeof(int32_t) * 2 + strlen(expected_body) + 2)); } +static void send_raw_packet(int sockfd, int32_t id, int32_t type, + const char *body) { + size_t body_len = strlen(body); + size_t packet_size = sizeof(int32_t) * 3 + body_len + 2; + uint8_t packet[sizeof(int32_t) * 3 + 4096 + 2] = {0}; + + assert_true(packet_size <= sizeof(packet)); + + int32_t size = htole32(packet_size - sizeof(int32_t)); + memcpy(packet, &size, sizeof(int32_t)); + + int32_t id_le = htole32(id); + memcpy(packet + 4, &id_le, sizeof(int32_t)); + + int32_t type_le = htole32(type); + memcpy(packet + 8, &type_le, sizeof(int32_t)); + + memcpy(packet + 12, body, body_len); + packet[12 + body_len] = 0; + packet[12 + body_len + 1] = 0; + + assert_int_equal((int)send(sockfd, packet, packet_size, 0), (int)packet_size); +} + typedef struct { int client; int server; @@ -33,6 +57,8 @@ static void close_socketpair(socketpair_t *sockets) { close(sockets->server); } +/// Serialization and deserialization /// + static void test_serialization_roundtrip_command_packet(void **state) { (void)state; @@ -118,6 +144,84 @@ static void test_recv_packet_errors_on_unexpected_close(void **state) { close_socketpair(&sockets); } +/// End of serialization and deserialization /// + +/// Authentication logic /// + +static void receive_auth_response(int sockfd, const char *expected_password) { + rcon_packet_t packet = {0}; + + assert_int_equal(recv_packet(sockfd, &packet), 0); + assert_packet_equals(&packet, 1, RCON_SERVERDATA_AUTH, expected_password); +} + +static void test_authenticate_successful_flow(void **state) { + (void)state; + + socketpair_t sockets = create_socketpair(); + + const char *password = "hunter2"; + + send_raw_packet(sockets.server, 1, RCON_SERVERDATA_RESPONSE_VALUE, ""); + send_raw_packet(sockets.server, 1, RCON_SERVERDATA_AUTH_RESPONSE, ""); + + assert_int_equal(rcon_authenticate(sockets.client, password), 0); + receive_auth_response(sockets.server, password); + + close(sockets.client); + close(sockets.server); +} + +static void test_authenticate_invalid_password(void **state) { + (void)state; + + socketpair_t sockets = create_socketpair(); + + const char *password = "wrong-password"; + + send_raw_packet(sockets.server, -1, RCON_SERVERDATA_RESPONSE_VALUE, ""); + + assert_int_equal(rcon_authenticate(sockets.client, password), -1); + receive_auth_response(sockets.server, password); + + close(sockets.client); + close(sockets.server); +} + +static void test_authenticate_unexpected_packet_type(void **state) { + (void)state; + + socketpair_t sockets = create_socketpair(); + + const char *password = "hunter2"; + + send_raw_packet(sockets.server, 1, RCON_SERVERDATA_EXECCOMMAND, ""); + + assert_int_equal(rcon_authenticate(sockets.client, password), -1); + receive_auth_response(sockets.server, password); + + close(sockets.client); + close(sockets.server); +} + +static void test_authenticate_mismatched_response_id(void **state) { + (void)state; + + socketpair_t sockets = create_socketpair(); + + const char *password = "hunter2"; + + send_raw_packet(sockets.server, 2, RCON_SERVERDATA_RESPONSE_VALUE, ""); + + assert_int_equal(rcon_authenticate(sockets.client, password), -1); + receive_auth_response(sockets.server, password); + + close(sockets.client); + close(sockets.server); +} + +/// End of authentication logic /// + int main(void) { const struct CMUnitTest tests[] = { cmocka_unit_test(test_serialization_roundtrip_command_packet), @@ -125,6 +229,10 @@ int main(void) { cmocka_unit_test(test_serialization_roundtrip_empty_body_packet), cmocka_unit_test(test_serialization_roundtrip_max_supported_body), cmocka_unit_test(test_recv_packet_errors_on_unexpected_close), + cmocka_unit_test(test_authenticate_successful_flow), + cmocka_unit_test(test_authenticate_invalid_password), + cmocka_unit_test(test_authenticate_unexpected_packet_type), + cmocka_unit_test(test_authenticate_mismatched_response_id), }; return cmocka_run_group_tests(tests, NULL, NULL);