diff --git a/http_parser.c b/http_parser.c index dcb28f9..91c217d 100644 --- a/http_parser.c +++ b/http_parser.c @@ -1438,7 +1438,7 @@ size_t http_parser_execute (http_parser *parser, /* Exit, the rest of the connect is in a different protocol. */ if (parser->upgrade) { CALLBACK2(message_complete); - return (p - data); + return (p - data) + 1; } if (parser->flags & F_SKIPBODY) { diff --git a/test.c b/test.c index b1d2e7b..c452303 100644 --- a/test.c +++ b/test.c @@ -55,7 +55,7 @@ struct message { char headers [MAX_HEADERS][2][MAX_ELEMENT_SIZE]; int should_keep_alive; - int upgrade; + const char *upgrade; // upgraded body unsigned short http_major; unsigned short http_minor; @@ -473,6 +473,7 @@ const struct message requests[] = "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "\r\n" + "Hot diggity dogg" ,.should_keep_alive= TRUE ,.message_complete_on_eof= FALSE ,.http_major= 1 @@ -483,7 +484,7 @@ const struct message requests[] = ,.request_path= "/demo" ,.request_url= "/demo" ,.num_headers= 7 - ,.upgrade=1 + ,.upgrade="Hot diggity dogg" ,.headers= { { "Host", "example.com" } , { "Connection", "Upgrade" } , { "Sec-WebSocket-Key2", "12998 5 Y3 1 .P00" } @@ -502,6 +503,8 @@ const struct message requests[] = "User-agent: Mozilla/1.1N\r\n" "Proxy-authorization: basic aGVsbG86d29ybGQ=\r\n" "\r\n" + "some data\r\n" + "and yet even more data" ,.should_keep_alive= FALSE ,.message_complete_on_eof= FALSE ,.http_major= 1 @@ -512,7 +515,7 @@ const struct message requests[] = ,.request_path= "" ,.request_url= "0-home0.netscape.com:443" ,.num_headers= 2 - ,.upgrade=1 + ,.upgrade="some data\r\nand yet even more data" ,.headers= { { "User-agent", "Mozilla/1.1N" } , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } } @@ -707,7 +710,7 @@ const struct message requests[] = ,.request_path= "" ,.request_url= "home_0.netscape.com:443" ,.num_headers= 2 - ,.upgrade=1 + ,.upgrade="" ,.headers= { { "User-agent", "Mozilla/1.1N" } , { "Proxy-authorization", "basic aGVsbG86d29ybGQ=" } } @@ -735,7 +738,6 @@ const struct message requests[] = ,.request_path= "/file.txt" ,.request_url= "/file.txt" ,.num_headers= 4 - ,.upgrade=0 ,.headers= { { "Host", "www.example.com" } , { "Content-Type", "application/example" } , { "If-Match", "\"e0023aa4e\"" } @@ -1337,7 +1339,13 @@ check_str_eq (const struct message *m, const char *prop, const char *expected, const char *found) { - if (0 != strcmp(expected, found)) { + if ((expected == NULL) != (found == NULL)) { + printf("\n*** Error: %s in '%s' ***\n\n", prop, m->name); + printf("expected %s\n", (expected == NULL) ? "NULL" : expected); + printf(" found %s\n", (found == NULL) ? "NULL" : found); + return 0; + } + if (expected != NULL && 0 != strcmp(expected, found)) { printf("\n*** Error: %s in '%s' ***\n\n", prop, m->name); printf("expected '%s'\n", expected); printf(" found '%s'\n", found); @@ -1410,9 +1418,75 @@ message_eq (int index, const struct message *expected) if (!r) return 0; } + MESSAGE_CHECK_STR_EQ(expected, m, upgrade); + return 1; } +/* Given a sequence of varargs messages, return the number of them that the + * parser should successfully parse, taking into account that upgraded + * messages prevent all subsequent messages from being parsed. + */ +size_t +count_parsed_messages(const size_t nmsgs, ...) { + size_t i; + va_list ap; + + va_start(ap, nmsgs); + + for (i = 0; i < nmsgs; i++) { + struct message *m = va_arg(ap, struct message *); + + if (m->upgrade) { + va_end(ap); + return i + 1; + } + } + + va_end(ap); + return nmsgs; +} + +/* Given a sequence of bytes and the number of these that we were able to + * parse, verify that upgrade bodies are correct. + */ +void +upgrade_message_fix(char *body, const size_t nread, const size_t nmsgs, ...) { + va_list ap; + size_t i; + size_t off = 0; + + va_start(ap, nmsgs); + + for (i = 0; i < nmsgs; i++) { + struct message *m = va_arg(ap, struct message *); + + off += strlen(m->raw); + + if (m->upgrade) { + off -= strlen(m->upgrade); + + /* Check the portion of the response after its specified upgrade */ + if (!check_str_eq(m, "upgrade", body + off, body + nread)) { + exit(1); + } + + /* Fix up the response so that message_eq() will verify the beginning + * of the upgrade */ + *(body + nread + strlen(m->upgrade)) = '\0'; + messages[num_messages -1 ].upgrade = body + nread; + + va_end(ap); + return; + } + } + + va_end(ap); + printf("\n\n*** Error: expected a message with upgrade ***\n"); + + exit(1); +} + static void print_error (const char *raw, size_t error_location) { @@ -1471,7 +1545,10 @@ test_message (const struct message *message) if (msg1len) { read = parse(msg1, msg1len); - if (message->upgrade && parser->upgrade) goto test; + if (message->upgrade && parser->upgrade) { + messages[num_messages - 1].upgrade = msg1 + read; + goto test; + } if (read != msg1len) { print_error(msg1, read); @@ -1482,7 +1559,10 @@ test_message (const struct message *message) read = parse(msg2, msg2len); - if (message->upgrade && parser->upgrade) goto test; + if (message->upgrade && parser->upgrade) { + messages[num_messages - 1].upgrade = msg2 + read; + goto test; + } if (read != msg2len) { print_error(msg2, read); @@ -1491,8 +1571,6 @@ test_message (const struct message *message) read = parse(NULL, 0); - if (message->upgrade && parser->upgrade) goto test; - if (read != 0) { print_error(message->raw, read); exit(1); @@ -1627,12 +1705,7 @@ test_no_overflow_long_body (int req, size_t length) void test_multiple3 (const struct message *r1, const struct message *r2, const struct message *r3) { - int message_count = 1; - if (!r1->upgrade) { - message_count++; - if (!r2->upgrade) message_count++; - } - int has_upgrade = (message_count < 3 || r3->upgrade); + int message_count = count_parsed_messages(3, r1, r2, r3); char total[ strlen(r1->raw) + strlen(r2->raw) @@ -1651,7 +1724,10 @@ test_multiple3 (const struct message *r1, const struct message *r2, const struct read = parse(total, strlen(total)); - if (has_upgrade && parser->upgrade) goto test; + if (parser->upgrade) { + upgrade_message_fix(total, read, 3, r1, r2, r3); + goto test; + } if (read != strlen(total)) { print_error(total, read); @@ -1660,8 +1736,6 @@ test_multiple3 (const struct message *r1, const struct message *r2, const struct read = parse(NULL, 0); - if (has_upgrade && parser->upgrade) goto test; - if (read != 0) { print_error(total, read); exit(1); @@ -1675,12 +1749,8 @@ test: } if (!message_eq(0, r1)) exit(1); - if (message_count > 1) { - if (!message_eq(1, r2)) exit(1); - if (message_count > 2) { - if (!message_eq(2, r3)) exit(1); - } - } + if (message_count > 1 && !message_eq(1, r2)) exit(1); + if (message_count > 2 && !message_eq(2, r3)) exit(1); parser_free(); } @@ -1709,6 +1779,7 @@ test_scan (const struct message *r1, const struct message *r2, const struct mess int ops = 0 ; size_t buf1_len, buf2_len, buf3_len; + int message_count = count_parsed_messages(3, r1, r2, r3); int i,j,type_both; for (type_both = 0; type_both < 2; type_both ++ ) { @@ -1737,27 +1808,27 @@ test_scan (const struct message *r1, const struct message *r2, const struct mess read = parse(buf1, buf1_len); - if (r3->upgrade && parser->upgrade) goto test; + if (parser->upgrade) goto test; if (read != buf1_len) { print_error(buf1, read); goto error; } - read = parse(buf2, buf2_len); + read += parse(buf2, buf2_len); - if (r3->upgrade && parser->upgrade) goto test; + if (parser->upgrade) goto test; - if (read != buf2_len) { + if (read != buf1_len + buf2_len) { print_error(buf2, read); goto error; } - read = parse(buf3, buf3_len); + read += parse(buf3, buf3_len); - if (r3->upgrade && parser->upgrade) goto test; + if (parser->upgrade) goto test; - if (read != buf3_len) { + if (read != buf1_len + buf2_len + buf3_len) { print_error(buf3, read); goto error; } @@ -1765,9 +1836,13 @@ test_scan (const struct message *r1, const struct message *r2, const struct mess parse(NULL, 0); test: + if (parser->upgrade) { + upgrade_message_fix(total, read, 3, r1, r2, r3); + } - if (3 != num_messages) { - fprintf(stderr, "\n\nParser didn't see 3 messages only %d\n", num_messages); + if (message_count != num_messages) { + fprintf(stderr, "\n\nParser didn't see %d messages only %d\n", + message_count, num_messages); goto error; } @@ -1776,12 +1851,12 @@ test: goto error; } - if (!message_eq(1, r2)) { + if (message_count > 1 && !message_eq(1, r2)) { fprintf(stderr, "\n\nError matching messages[1] in test_scan.\n"); goto error; } - if (!message_eq(2, r3)) { + if (message_count > 2 && !message_eq(2, r3)) { fprintf(stderr, "\n\nError matching messages[2] in test_scan.\n"); goto error; }