diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 29b156fd11..75cea86104 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -52,11 +52,12 @@ typedef struct { typedef struct { char *path; - char *buffer; char *sub_protocol; char *user_agent; char *headers; char *auth; + char *buffer; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/ + size_t buffer_len; /*!< The buffer length */ int http_status_code; bool propagate_control_frames; ws_transport_frame_state_t frame_state; @@ -101,6 +102,35 @@ static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_hand return ws->parent; } +static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len, int timeout_ms) +{ + // No buffered data to read from, directly attempt to read from the transport. + if (ws->buffer_len == 0) { + return esp_transport_read(ws->parent, buffer, len, timeout_ms); + } + + // At this point, buffer_len is guaranteed to be > 0. + int to_read = (ws->buffer_len >= len) ? len : ws->buffer_len; + + // Copy the available or requested data to the buffer. + memcpy(buffer, ws->buffer, to_read); + + if (to_read < ws->buffer_len) { + // Shift remaining data if not all was read. + memmove(ws->buffer, ws->buffer + to_read, ws->buffer_len - to_read); + ws->buffer_len -= to_read; + } else { + // All buffer data was consumed. +#ifdef CONFIG_WS_DYNAMIC_BUFFER + free(ws->buffer); + ws->buffer = NULL; +#endif + ws->buffer_len = 0; + } + + return to_read; +} + static char *trimwhitespace(const char *str) { char *end; @@ -164,6 +194,8 @@ static char *get_http_header(const char *buffer, const char *key) static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); + const char delimiter[] = "\r\n\r\n"; + if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) { ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port); return -1; @@ -256,9 +288,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return -1; } header_len += len; - ws->buffer[header_len] = '\0'; + ws->buffer_len = header_len; + ws->buffer[header_len] = '\0'; // We will mark the end of the header to ensure that strstr operations for parsing the headers don't fail. ESP_LOGD(TAG, "Read header chunk %d, current header size: %d", len, header_len); - } while (NULL == strstr(ws->buffer, "\r\n\r\n") && header_len < WS_BUFFER_SIZE); + } while (NULL == strstr(ws->buffer, delimiter) && header_len < WS_BUFFER_SIZE); + + char* delim_ptr = strstr(ws->buffer, delimiter); ws->http_status_code = get_http_status_code(ws->buffer); if (ws->http_status_code == -1) { @@ -272,6 +307,20 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return -1; } + if (delim_ptr != NULL) { + size_t delim_pos = delim_ptr - ws->buffer + sizeof(delimiter) - 1; + size_t remaining_len = ws->buffer_len - delim_pos; + if (remaining_len > 0) { + memmove(ws->buffer, ws->buffer + delim_pos, remaining_len); + ws->buffer_len = remaining_len; + } else { +#ifdef CONFIG_WS_DYNAMIC_BUFFER + free(ws->buffer); + ws->buffer = NULL; +#endif + ws->buffer_len = 0; + } + } // See esp_crypto_sha1() arg size unsigned char expected_server_sha1[20]; // Size of base64 coded string see above @@ -291,10 +340,6 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int ESP_LOGE(TAG, "Invalid websocket key"); return -1; } -#ifdef CONFIG_WS_DYNAMIC_BUFFER - free(ws->buffer); - ws->buffer = NULL; -#endif return 0; } @@ -406,7 +451,7 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int } // Receive and process payload - if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) { + if (bytes_to_read != 0 && (rlen = esp_transport_read_internal(ws, buffer, bytes_to_read, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -437,7 +482,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t // Receive and process header first (based on header size) int header = 2; int mask_len = 4; - if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -451,7 +496,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; - if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -459,7 +504,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t } else if (payload_len == 127) { // headerLen += 8; header = 8; - if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -474,7 +519,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t if (mask) { // Read and store mask - if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) { + if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; }