diff --git a/src/main.rs b/src/main.rs index f522375..0a3623e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -151,12 +151,6 @@ async fn handle_connection(stream: TcpStream, db: Db) -> Result<()> { loop { // Check for pubsub messages if subscribed if client.is_subscribed() { - // First, process any remaining data in the buffer before waiting - if !buffer.is_empty() { - process_buffer(&mut buffer, &mut writer, &db, &mut client).await?; - writer.flush().await?; - continue; - } tokio::select! { // Wait for pubsub messages Some(msg) = client.pubsub_rx.recv() => { @@ -186,11 +180,6 @@ async fn handle_connection(stream: TcpStream, db: Db) -> Result<()> { } } else { // Not subscribed - simple read loop - // First, process any remaining data in the buffer - if !buffer.is_empty() { - process_buffer(&mut buffer, &mut writer, &db, &mut client).await?; - continue; - } let n = reader.read_buf(&mut buffer).await?; if n == 0 { println!("Connection closed"); diff --git a/src/resp/parser.rs b/src/resp/parser.rs index 5551323..a534ddf 100644 --- a/src/resp/parser.rs +++ b/src/resp/parser.rs @@ -50,13 +50,25 @@ fn parse_bulk_string(buffer: BytesMut) -> Result<(Value, usize)> { return Err(anyhow::anyhow!("Invalid bulk string format {:?}", buffer)); }; - let end_of_bulk_str = bytes_consumed + bulk_str_len as usize; + if bulk_str_len == -1 { + return Ok((Value::NullBulk, bytes_consumed)); + } + if bulk_str_len < 0 { + return Err(anyhow::anyhow!("Invalid bulk string length")); + } + + let bulk_str_len = bulk_str_len as usize; + let end_of_bulk_str = bytes_consumed + bulk_str_len; let total_parsed = end_of_bulk_str + 2; - Ok(( - Value::BulkString(String::from_utf8(buffer[bytes_consumed..end_of_bulk_str].to_vec())?), - total_parsed, - )) + if buffer.len() < total_parsed { + return Err(anyhow::anyhow!("Incomplete bulk string data")); + } + + let data = &buffer[bytes_consumed..end_of_bulk_str]; + let s = String::from_utf8(data.to_vec())?; + + Ok((Value::BulkString(s), total_parsed)) } fn read_until_crlf(buffer: &[u8]) -> Option<(&[u8], usize)> {