From 0ec3312edc2fddbaabf58018f1c59c1a56ba09d1 Mon Sep 17 00:00:00 2001 From: Mathieu Amiot Date: Fri, 14 Dec 2018 18:16:16 +0100 Subject: [PATCH 1/2] Rust 2018 + Parser overhaul --- Cargo.toml | 6 +-- benches/nitox_parser_benchmark.rs | 36 +++++++------- src/client.rs | 16 +++---- src/codec.rs | 30 +++++++----- src/net/connection.rs | 9 ++-- src/net/connection_inner.rs | 6 +-- src/net/mod.rs | 5 +- src/protocol/client/connect.rs | 29 +++++++----- src/protocol/client/pub_cmd.rs | 62 ++++++++++++------------ src/protocol/client/sub_cmd.rs | 49 ++++++++++--------- src/protocol/client/unsub_cmd.rs | 48 +++++++++++-------- src/protocol/mod.rs | 6 +-- src/protocol/op.rs | 73 ++++++++++++++++------------- src/protocol/server/info.rs | 26 ++++++---- src/protocol/server/message.rs | 33 ++++++------- src/protocol/server/server_error.rs | 7 +-- 16 files changed, 239 insertions(+), 202 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1da6b34..60c98d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ name = "nitox" readme = "README.md" repository = "https://github.com/YellowInnovation/nitox" version = "0.1.10" +edition = "2018" [[bench]] harness = false @@ -40,16 +41,13 @@ parking_lot = "0.7" rand = "0.6" serde = "1.0" serde_derive = "1.0" +serde_json = "1.0" tokio-codec = "0.1" tokio-executor = "0.1" tokio-tcp = "0.1" tokio-tls = "0.2" url = "1.7" -[dependencies.serde_json] -features = ["preserve_order"] -version = "1.0" - [dev-dependencies] criterion = "0.2" env_logger = "0.6" diff --git a/benches/nitox_parser_benchmark.rs b/benches/nitox_parser_benchmark.rs index 2b929f3..7aed536 100644 --- a/benches/nitox_parser_benchmark.rs +++ b/benches/nitox_parser_benchmark.rs @@ -8,15 +8,15 @@ use nitox::commands::*; fn benchmark_parser(c: &mut Criterion) { c.bench_function("connect_parse", |b| { - let cmd = b"CONNECT\t{\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"nitox\",\"lang\":\"rust\",\"version\":\"1.0.0\"}\r\n"; - b.iter(|| ConnectCommand::try_parse(cmd)) + let cmd = "CONNECT\t{\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"nitox\",\"lang\":\"rust\",\"version\":\"1.0.0\"}\r\n"; + b.iter(|| ConnectCommand::try_parse(cmd.into())) }); c.bench_function("connect_write", |b| b.iter(|| ConnectCommand::default().into_vec())); c.bench_function("pub_parse", |b| { - let cmd = b"PUB\tFOO\t11\r\nHello NATS!\r\n"; - b.iter(|| PubCommand::try_parse(cmd)) + let cmd = "PUB\tFOO\t11\r\nHello NATS!\r\n"; + b.iter(|| PubCommand::try_parse(cmd.into())) }); c.bench_function("pub_write", |b| { @@ -25,13 +25,14 @@ fn benchmark_parser(c: &mut Criterion) { subject: String::new(), payload: bytes::Bytes::new(), reply_to: None, - }.into_vec() + } + .into_vec() }) }); c.bench_function("sub_parse", |b| { - let cmd = b"SUB\tFOO\tpouet\r\n"; - b.iter(|| SubCommand::try_parse(cmd)) + let cmd = "SUB\tFOO\tpouet\r\n"; + b.iter(|| SubCommand::try_parse(cmd.into())) }); c.bench_function("sub_write", |b| { @@ -40,13 +41,14 @@ fn benchmark_parser(c: &mut Criterion) { queue_group: None, sid: String::new(), subject: String::new(), - }.into_vec() + } + .into_vec() }) }); c.bench_function("unsub_parse", |b| { - let cmd = b"UNSUB\tpouet\r\n"; - b.iter(|| UnsubCommand::try_parse(cmd)) + let cmd = "UNSUB\tpouet\r\n"; + b.iter(|| UnsubCommand::try_parse(cmd.into())) }); c.bench_function("unsub_write", |b| { @@ -54,13 +56,14 @@ fn benchmark_parser(c: &mut Criterion) { UnsubCommand { max_msgs: None, sid: String::new(), - }.into_vec() + } + .into_vec() }) }); c.bench_function("info_parse", |b| { - let cmd = b"INFO\t{\"server_id\":\"test\",\"version\":\"1.3.0\",\"go\":\"go1.10.3\",\"host\":\"0.0.0.0\",\"port\":4222,\"max_payload\":4000,\"proto\":1,\"client_id\":1337}\r\n"; - b.iter(|| ServerInfo::try_parse(cmd)) + let cmd = "INFO\t{\"server_id\":\"test\",\"version\":\"1.3.0\",\"go\":\"go1.10.3\",\"host\":\"0.0.0.0\",\"port\":4222,\"max_payload\":4000,\"proto\":1,\"client_id\":1337}\r\n"; + b.iter(|| ServerInfo::try_parse(cmd.into())) }); c.bench_function("info_write", |b| { @@ -79,8 +82,8 @@ fn benchmark_parser(c: &mut Criterion) { }); c.bench_function("message_parse", |b| { - let cmd = b"MSG\tFOO\tpouet\t4\r\ntoto\r\n"; - b.iter(|| Message::try_parse(cmd)) + let cmd = "MSG\tFOO\tpouet\t4\r\ntoto\r\n"; + b.iter(|| Message::try_parse(cmd.into())) }); c.bench_function("message_write", |b| { @@ -90,7 +93,8 @@ fn benchmark_parser(c: &mut Criterion) { sid: String::new(), reply_to: None, payload: bytes::Bytes::new(), - }.into_vec() + } + .into_vec() }) }); } diff --git a/src/client.rs b/src/client.rs index 9be1f9c..a54f231 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,9 +17,9 @@ use std::{ use tokio_executor; use url::Url; -use error::NatsError; -use net::*; -use protocol::{commands::*, Op}; +use crate::error::NatsError; +use crate::net::*; +use crate::protocol::{commands::*, Op}; /// Sink (write) part of a TCP stream type NatsSink = stream::SplitSink; @@ -113,10 +113,10 @@ impl NatsClientMultiplexer { (NatsClientMultiplexer { subs_tx, other_tx }, other_rx) } - pub fn for_sid(&self, sid: NatsSubscriptionId) -> impl Stream + Send + Sync { + pub fn for_sid(&self, sid: &str) -> impl Stream + Send + Sync { let (tx, rx) = mpsc::unbounded(); (*self.subs_tx.write()).insert( - sid, + sid.into(), SubscriptionSink { tx, max_count: None, @@ -330,7 +330,7 @@ impl NatsClient { let inner_rx = self.rx.clone(); let sid = cmd.sid.clone(); self.tx.send(Op::SUB(cmd)).and_then(move |_| { - let stream = inner_rx.for_sid(sid.clone()).and_then(move |msg| { + let stream = inner_rx.for_sid(&sid).and_then(move |msg| { { let mut stx = inner_rx.subs_tx.write(); let mut delete = None; @@ -390,7 +390,7 @@ impl NatsClient { let sid = sub_cmd.sid.clone(); let unsub_cmd = UnsubCommand { - sid: sub_cmd.sid.clone(), + sid: sid.clone(), max_msgs: Some(1), }; @@ -400,7 +400,7 @@ impl NatsClient { let stream = self .rx - .for_sid(sid.clone()) + .for_sid(&sid) .inspect(|msg| debug!(target: "nitox", "Request saw msg in multiplexed stream {:#?}", msg)) .take(1) .into_future() diff --git a/src/codec.rs b/src/codec.rs index 0fea8d6..e40677e 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1,6 +1,6 @@ +use crate::error::NatsError; +use crate::protocol::Op; use bytes::{BufMut, BytesMut}; -use error::NatsError; -use protocol::{CommandError, Op}; use tokio_codec::{Decoder, Encoder}; /// `tokio-codec` implementation of the protocol parsing @@ -41,15 +41,24 @@ impl Decoder for OpCodec { return Ok(None); } + debug!(target: "nitox", "next index: {}", self.next_index); debug!(target: "nitox", "codec buffer is {:?}", buf); + // Let's check if we find a blank space at the beginning if let Some(command_offset) = buf[self.next_index..] .iter() .position(|b| *b == b' ' || *b == b'\t' || *b == b'\r') { let command_end = self.next_index + command_offset; + + debug!(target: "nitox", "command end: {}", command_end); debug!(target: "nitox", "codec detected command name {:?}", &buf[..command_end]); + if !Op::command_exists(&buf[..command_end]) { + debug!(target: "nitox", "command was incomplete"); + return Ok(None); + } + if let Some(command_body_offset) = buf[command_end..].windows(2).position(|w| w == b"\r\n") { let mut end_buf_pos = command_end + command_body_offset + 2; @@ -65,26 +74,23 @@ impl Decoder for OpCodec { } debug!(target: "nitox", "codec detected command body {:?}", &buf[..end_buf_pos]); - match Op::from_bytes(&buf[..command_end], &buf[..end_buf_pos]) { - Err(CommandError::IncompleteCommandError) => { - debug!(target: "nitox", "command was incomplete"); - self.next_index = buf.len(); - Ok(None) - } + + let cmd_buf = buf.split_to(end_buf_pos); + debug!(target: "nitox", "buffer now contains {:?}", buf); + self.next_index = 0; + + match Op::from_bytes(cmd_buf.freeze(), command_end) { Ok(op) => { debug!(target: "nitox", "codec parsed command {:#?}", op); - let _ = buf.split_to(end_buf_pos); - debug!(target: "nitox", "buffer now contains {:?}", buf); - self.next_index = 0; Ok(Some(op)) } Err(e) => { debug!(target: "nitox", "command couldn't be parsed {}", e); - self.next_index = 0; Err(e.into()) } } } else { + debug!(target: "nitox", "command was incomplete"); Ok(None) } } else { diff --git a/src/net/connection.rs b/src/net/connection.rs index a17c1d6..fb03020 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -6,8 +6,8 @@ use parking_lot::RwLock; use std::{net::SocketAddr, sync::Arc}; use tokio_executor; -use error::NatsError; -use protocol::Op; +use crate::error::NatsError; +use crate::protocol::Op; use super::connection_inner::NatsConnectionInner; @@ -54,7 +54,7 @@ impl NatsConnection { let inner_arc = Arc::clone(&self.inner); let inner_state = Arc::clone(&self.state); let is_tls = self.is_tls; - let maybe_host = self.host.clone(); + let maybe_host: Option = self.host.clone(); NatsConnectionInner::connect_tcp(&self.addr) .and_then(move |socket| { if is_tls { @@ -66,7 +66,8 @@ impl NatsConnection { } else { Either::B(future::ok(NatsConnectionInner::from(socket))) } - }).and_then(move |inner| { + }) + .and_then(move |inner| { { *inner_arc.write() = inner; *inner_state.write() = NatsConnectionState::Connected; diff --git a/src/net/connection_inner.rs b/src/net/connection_inner.rs index 8a5126a..94bb914 100644 --- a/src/net/connection_inner.rs +++ b/src/net/connection_inner.rs @@ -1,16 +1,16 @@ -use codec::OpCodec; +use crate::codec::OpCodec; +use crate::protocol::Op; use futures::{ future::{self, Either}, prelude::*, }; use native_tls::TlsConnector as NativeTlsConnector; -use protocol::Op; use std::net::SocketAddr; use tokio_codec::{Decoder, Framed}; use tokio_tcp::TcpStream; use tokio_tls::{TlsConnector, TlsStream}; -use error::NatsError; +use crate::error::NatsError; /// Inner raw stream enum over TCP and TLS/TCP #[derive(Debug)] diff --git a/src/net/mod.rs b/src/net/mod.rs index a18b3b1..c107cf2 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -6,7 +6,7 @@ use std::sync::Arc; pub(crate) mod connection; mod connection_inner; -use error::NatsError; +use crate::error::NatsError; use self::connection::NatsConnectionState; use self::connection_inner::*; @@ -34,7 +34,8 @@ pub(crate) fn connect_tls(host: String, addr: SocketAddr) -> impl Future Result { - match ::std::env::var("CARGO_PKG_VERSION") { - Ok(v) => Ok(v), - Err(_) => Ok("0.1.x".into()), - } + Ok(env!("CARGO_PKG_VERSION").into()) } fn default_lang(&self) -> Result { @@ -73,10 +70,16 @@ impl Command for ConnectCommand { const CMD_NAME: &'static [u8] = b"CONNECT"; fn into_vec(self) -> Result { - Ok(format!("CONNECT\t{}\r\n", json::to_string(&self)?).as_bytes().into()) + let json_cmd = json::to_vec(&self)?; + let mut cmd: BytesMut = BytesMut::with_capacity(10 + json_cmd.len()); + cmd.put("CONNECT\t"); + cmd.put(json_cmd); + cmd.put("\r\n"); + + Ok(cmd.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { @@ -94,22 +97,22 @@ impl Command for ConnectCommand { #[cfg(test)] mod tests { use super::{ConnectCommand, ConnectCommandBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_CONNECT: &'static str = "CONNECT\t{\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"nitox\",\"lang\":\"rust\",\"version\":\"1.0.0\"}\r\n"; #[test] fn it_parses() { - let parse_res = ConnectCommand::try_parse(DEFAULT_CONNECT.as_bytes()); + let parse_res = ConnectCommand::try_parse(DEFAULT_CONNECT.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); assert_eq!(cmd.verbose, false); assert_eq!(cmd.pedantic, false); assert_eq!(cmd.tls_required, false); assert!(cmd.name.is_some()); - assert_eq!(&cmd.name.unwrap(), "nitox"); - assert_eq!(&cmd.lang, "rust"); - assert_eq!(&cmd.version, "1.0.0"); + assert_eq!(cmd.name.unwrap(), "nitox"); + assert_eq!(cmd.lang, "rust"); + assert_eq!(cmd.version, "1.0.0"); } #[test] diff --git a/src/protocol/client/pub_cmd.rs b/src/protocol/client/pub_cmd.rs index 42e7742..00b039c 100644 --- a/src/protocol/client/pub_cmd.rs +++ b/src/protocol/client/pub_cmd.rs @@ -1,5 +1,5 @@ +use crate::protocol::{Command, CommandError}; use bytes::{BufMut, Bytes, BytesMut}; -use protocol::{Command, CommandError}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; /// The PUB message publishes the message payload to the given subject name, optionally supplying a reply subject. @@ -26,8 +26,7 @@ impl PubCommand { /// Generates a random `reply_to` `String` pub fn generate_reply_to() -> String { - let mut rng = thread_rng(); - rng.sample_iter(&Alphanumeric).take(16).collect() + thread_rng().sample_iter(&Alphanumeric).take(16).collect() } } @@ -35,22 +34,28 @@ impl Command for PubCommand { const CMD_NAME: &'static [u8] = b"PUB"; fn into_vec(self) -> Result { - let rt = if let Some(reply_to) = self.reply_to { - format!("\t{}", reply_to) - } else { - "".into() - }; - - let cmd_str = format!("PUB\t{}{}\t{}\r\n", self.subject, rt, self.payload.len()); - let mut bytes = BytesMut::with_capacity(cmd_str.len() + self.payload.len() + 2); - bytes.put(cmd_str.as_bytes()); + let (rt_len, rt) = self.reply_to.map_or((0, "".into()), |rp| (rp.len() + 1, rp)); + // Computes the string length of the payload_len by dividing the number par ln(10) + let size_len = ((self.payload.len() + 1) as f64 / std::f64::consts::LN_10).ceil() as usize; + let len = 9 + self.subject.len() + rt_len + size_len + self.payload.len(); + + let mut bytes = BytesMut::with_capacity(len); + bytes.put("PUB\t"); + bytes.put(self.subject); + if rt_len > 0 { + bytes.put(b'\t'); + bytes.put(rt); + } + bytes.put(b'\t'); + bytes.put(self.payload.len().to_string()); + bytes.put("\r\n"); bytes.put(self.payload); bytes.put("\r\n"); Ok(bytes.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { @@ -64,27 +69,28 @@ impl Command for PubCommand { let payload: Bytes = buf[payload_start + 2..len - 2].into(); - let whole_command = ::std::str::from_utf8(&buf[..payload_start])?; - let mut split = whole_command.split_whitespace(); + let mut split = buf[..payload_start].split(|c| *c == b' ' || *c == b'\t'); let cmd = split.next().ok_or_else(|| CommandError::CommandMalformed)?; // Check if we're still on the right command - if cmd.as_bytes() != Self::CMD_NAME { + if cmd != Self::CMD_NAME { return Err(CommandError::CommandMalformed); } - let payload_len: usize = split - .next_back() - .ok_or_else(|| CommandError::CommandMalformed)? - .parse()?; + let payload_len: usize = + std::str::from_utf8(split.next_back().ok_or_else(|| CommandError::CommandMalformed)?)?.parse()?; if payload.len() != payload_len { return Err(CommandError::CommandMalformed); } // Extract subject - let subject: String = split.next().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let subject: String = + std::str::from_utf8(split.next().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); - let reply_to: Option = split.next().map(|v| v.into()); + let reply_to: Option = match split.next() { + Some(v) => Some(std::str::from_utf8(v)?.into()), + _ => None, + }; Ok(PubCommand { subject, @@ -98,12 +104,6 @@ impl Command for PubCommand { } impl PubCommandBuilder { - pub fn auto_reply_to(&mut self) -> &mut Self { - let inbox = PubCommand::generate_reply_to(); - self.reply_to = Some(Some(inbox)); - self - } - fn validate(&self) -> Result<(), String> { if let Some(ref subj) = self.subject { check_cmd_arg!(subj, "subject"); @@ -122,16 +122,16 @@ impl PubCommandBuilder { #[cfg(test)] mod tests { use super::{PubCommand, PubCommandBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_PUB: &'static str = "PUB\tFOO\t11\r\nHello NATS!\r\n"; #[test] fn it_parses() { - let parse_res = PubCommand::try_parse(DEFAULT_PUB.as_bytes()); + let parse_res = PubCommand::try_parse(DEFAULT_PUB.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); - assert_eq!(&cmd.subject, "FOO"); + assert_eq!(cmd.subject, "FOO"); assert_eq!(&cmd.payload, "Hello NATS!"); assert!(cmd.reply_to.is_none()); } diff --git a/src/protocol/client/sub_cmd.rs b/src/protocol/client/sub_cmd.rs index df73348..c2766bf 100644 --- a/src/protocol/client/sub_cmd.rs +++ b/src/protocol/client/sub_cmd.rs @@ -1,5 +1,5 @@ -use bytes::Bytes; -use protocol::{Command, CommandError}; +use crate::protocol::{Command, CommandError}; +use bytes::{BufMut, Bytes, BytesMut}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; /// SUB initiates a subscription to a subject, optionally joining a distributed queue group. @@ -32,38 +32,45 @@ impl Command for SubCommand { const CMD_NAME: &'static [u8] = b"SUB"; fn into_vec(self) -> Result { - let qg = if let Some(queue_group) = self.queue_group { - format!("\t{}", queue_group) - } else { - "".into() - }; + let (qg_len, qg) = self.queue_group.map_or((0, "".into()), |qg| (qg.len() + 1, qg)); + let len = 7 + self.subject.len() + qg_len + self.sid.len(); + let mut bytes = BytesMut::with_capacity(len); + bytes.put("SUB\t"); + bytes.put(self.subject); + if qg_len > 0 { + bytes.put_u8(b'\t'); + bytes.put(qg); + } + bytes.put_u8(b'\t'); + bytes.put(self.sid); + bytes.put("\r\n"); - Ok(format!("SUB\t{}{}\t{}\r\n", self.subject, qg, self.sid) - .as_bytes() - .into()) + Ok(bytes.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { return Err(CommandError::IncompleteCommandError); } - let whole_command = ::std::str::from_utf8(&buf[..len - 2])?; - let mut split = whole_command.split_whitespace(); + let mut split = buf[..len - 2].split(|c| *c == b' ' || *c == b'\t'); let cmd = split.next().ok_or_else(|| CommandError::CommandMalformed)?; // Check if we're still on the right command - if cmd.as_bytes() != Self::CMD_NAME { + if cmd != Self::CMD_NAME { return Err(CommandError::CommandMalformed); } // Extract subject - let subject: String = split.next().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let subject: String = std::str::from_utf8(split.next().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); // Extract subscription id - let sid: String = split.next_back().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let sid: String = std::str::from_utf8(split.next_back().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); // Extract queue group if exists - let queue_group: Option = split.next().map(|v| v.into()); + let queue_group: Option = match split.next() { + Some(v) => Some(std::str::from_utf8(v)?.into()), + _ => None, + }; Ok(SubCommand { subject, @@ -92,17 +99,17 @@ impl SubCommandBuilder { #[cfg(test)] mod tests { use super::{SubCommand, SubCommandBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_SUB: &'static str = "SUB\tFOO\tpouet\r\n"; #[test] fn it_parses() { - let parse_res = SubCommand::try_parse(DEFAULT_SUB.as_bytes()); + let parse_res = SubCommand::try_parse(DEFAULT_SUB.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); - assert_eq!(&cmd.subject, "FOO"); - assert_eq!(&cmd.sid, "pouet") + assert_eq!(cmd.subject, "FOO"); + assert_eq!(cmd.sid, "pouet") } #[test] diff --git a/src/protocol/client/unsub_cmd.rs b/src/protocol/client/unsub_cmd.rs index 51a43a0..f6ceafa 100644 --- a/src/protocol/client/unsub_cmd.rs +++ b/src/protocol/client/unsub_cmd.rs @@ -1,5 +1,5 @@ -use bytes::Bytes; -use protocol::{commands::SubCommand, Command, CommandError}; +use crate::protocol::{commands::SubCommand, Command, CommandError}; +use bytes::{BufMut, Bytes, BytesMut}; /// UNSUB unsubcribes the connection from the specified subject, or auto-unsubscribes after the /// specified number of messages has been received. @@ -32,36 +32,44 @@ impl Command for UnsubCommand { const CMD_NAME: &'static [u8] = b"UNSUB"; fn into_vec(self) -> Result { - let mm = if let Some(max_msgs) = self.max_msgs { - format!("\t{}", max_msgs) - } else { - "".into() - }; + // Computes the string length of the payload_len by dividing the number par ln(10) + let (mm_len, mm) = self.max_msgs.map_or((0, 0), |mm| { + (((mm + 1) as f64 / std::f64::consts::LN_10).ceil() as usize, mm) + }); + + let len = 8 + self.sid.len() + mm_len; + + let mut bytes = BytesMut::with_capacity(len); + bytes.put("UNSUB\t"); + bytes.put(self.sid); + if mm_len > 0 { + bytes.put(b'\t'); + bytes.put(mm.to_string()); + } + bytes.put("\r\n"); - Ok(format!("UNSUB\t{}{}\r\n", self.sid, mm).as_bytes().into()) + Ok(bytes.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { return Err(CommandError::IncompleteCommandError); } - let whole_command = ::std::str::from_utf8(&buf[..len - 2])?; - let mut split = whole_command.split_whitespace(); + let mut split = buf[..len - 2].split(|c| *c == b' ' || *c == b'\t'); let cmd = split.next().ok_or_else(|| CommandError::CommandMalformed)?; // Check if we're still on the right command - if cmd.as_bytes() != Self::CMD_NAME { + if cmd != Self::CMD_NAME { return Err(CommandError::CommandMalformed); } - let sid: String = split.next().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let sid: String = std::str::from_utf8(split.next().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); - let max_msgs: Option = if let Some(mm) = split.next() { - Some(mm.parse()?) - } else { - None + let max_msgs: Option = match split.next() { + Some(mm) => Some(std::str::from_utf8(mm)?.parse()?), + _ => None, }; Ok(UnsubCommand { sid, max_msgs }) @@ -71,16 +79,16 @@ impl Command for UnsubCommand { #[cfg(test)] mod tests { use super::{UnsubCommand, UnsubCommandBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_UNSUB: &'static str = "UNSUB\tpouet\r\n"; #[test] fn it_parses() { - let parse_res = UnsubCommand::try_parse(DEFAULT_UNSUB.as_bytes()); + let parse_res = UnsubCommand::try_parse(DEFAULT_UNSUB.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); - assert_eq!(&cmd.sid, "pouet"); + assert_eq!(cmd.sid, "pouet"); assert!(cmd.max_msgs.is_none()); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 69af1ce..28751b9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -7,7 +7,7 @@ pub trait Command { /// Encodes the command into bytes fn into_vec(self) -> Result; /// Tries to parse a buffer into a command - fn try_parse(buf: &[u8]) -> Result + fn try_parse(buf: Bytes) -> Result where Self: Sized; } @@ -24,7 +24,7 @@ pub(crate) fn check_command_arg(s: &str) -> Result<(), ArgumentValidationError> macro_rules! check_cmd_arg { ($val:ident, $part:expr) => { - use protocol::{check_command_arg, ArgumentValidationError}; + use crate::protocol::{check_command_arg, ArgumentValidationError}; match check_command_arg($val) { Ok(_) => {} @@ -52,7 +52,7 @@ pub mod commands { client::{connect::*, pub_cmd::*, sub_cmd::*, unsub_cmd::*}, server::{info::*, message::*, server_error::ServerError}, }; - pub use Command; + pub use crate::Command; } #[cfg(test)] diff --git a/src/protocol/op.rs b/src/protocol/op.rs index 65f91f9..28a8e88 100644 --- a/src/protocol/op.rs +++ b/src/protocol/op.rs @@ -26,18 +26,6 @@ pub enum Op { ERR(ServerError), } -macro_rules! op_from_cmd { - ($buf:ident, $cmd:path, $op:path) => {{ - use protocol::CommandError; - - match $cmd(&$buf) { - Ok(c) => Ok($op(c)), - Err(CommandError::IncompleteCommandError) => return Err(CommandError::IncompleteCommandError), - Err(e) => return Err(e.into()), - } - }}; -} - impl Op { /// Transforms the OP into a byte slice pub fn into_bytes(self) -> Result { @@ -56,49 +44,68 @@ impl Op { } /// Tries to parse from a pair of command name and whole buffer - pub fn from_bytes(cmd_name: &[u8], buf: &[u8]) -> Result { - match cmd_name { - ServerInfo::CMD_NAME => op_from_cmd!(buf, ServerInfo::try_parse, Op::INFO), - ConnectCommand::CMD_NAME => op_from_cmd!(buf, ConnectCommand::try_parse, Op::CONNECT), - Message::CMD_NAME => op_from_cmd!(buf, Message::try_parse, Op::MSG), - PubCommand::CMD_NAME => op_from_cmd!(buf, PubCommand::try_parse, Op::PUB), - SubCommand::CMD_NAME => op_from_cmd!(buf, SubCommand::try_parse, Op::SUB), - UnsubCommand::CMD_NAME => op_from_cmd!(buf, UnsubCommand::try_parse, Op::UNSUB), + pub fn from_bytes(buf: Bytes, cmd_idx: usize) -> Result { + let mut cmd_name = vec![0; cmd_idx]; + cmd_name.copy_from_slice(&buf[..cmd_idx]); + + Ok(match &*cmd_name { + ServerInfo::CMD_NAME => Op::INFO(ServerInfo::try_parse(buf)?), + ConnectCommand::CMD_NAME => Op::CONNECT(ConnectCommand::try_parse(buf)?), + Message::CMD_NAME => Op::MSG(Message::try_parse(buf)?), + PubCommand::CMD_NAME => Op::PUB(PubCommand::try_parse(buf)?), + SubCommand::CMD_NAME => Op::SUB(SubCommand::try_parse(buf)?), + UnsubCommand::CMD_NAME => Op::UNSUB(UnsubCommand::try_parse(buf)?), b"PING" => { - if buf == b"PING\r\n" { - Ok(Op::PING) + if buf == "PING\r\n" { + Op::PING } else { - Err(CommandError::IncompleteCommandError) + return Err(CommandError::IncompleteCommandError); } } b"PONG" => { - if buf == b"PONG\r\n" { - Ok(Op::PONG) + if buf == "PONG\r\n" { + Op::PONG } else { - Err(CommandError::IncompleteCommandError) + return Err(CommandError::IncompleteCommandError); } } b"+OK" => { - if buf == b"+OK\r\n" { - Ok(Op::OK) + if buf == "+OK\r\n" { + Op::OK } else { - Err(CommandError::IncompleteCommandError) + return Err(CommandError::IncompleteCommandError); } } b"-ERR" => { if &buf[buf.len() - 2..] == b"\r\n" { - Ok(Op::ERR(ServerError::from(String::from_utf8(buf[1..].to_vec())?))) + Op::ERR(ServerError(std::str::from_utf8(&buf[1..])?.into())) } else { - Err(CommandError::IncompleteCommandError) + return Err(CommandError::IncompleteCommandError); } } _ => { if buf.len() > 7 { - Err(CommandError::CommandNotFoundOrSupported) + return Err(CommandError::CommandNotFoundOrSupported); } else { - Err(CommandError::IncompleteCommandError) + return Err(CommandError::IncompleteCommandError); } } + }) + } + + pub fn command_exists(cmd_name: &[u8]) -> bool { + match cmd_name { + ServerInfo::CMD_NAME => true, + ConnectCommand::CMD_NAME => true, + Message::CMD_NAME => true, + PubCommand::CMD_NAME => true, + SubCommand::CMD_NAME => true, + UnsubCommand::CMD_NAME => true, + b"PING" => true, + b"PONG" => true, + b"+OK" => true, + b"-ERR" => true, + _ => false, } } } diff --git a/src/protocol/server/info.rs b/src/protocol/server/info.rs index ac97a28..d33b26a 100644 --- a/src/protocol/server/info.rs +++ b/src/protocol/server/info.rs @@ -1,5 +1,5 @@ -use bytes::Bytes; -use protocol::{Command, CommandError}; +use crate::protocol::{Command, CommandError}; +use bytes::{BufMut, Bytes, BytesMut}; use serde_json as json; /// As soon as the server accepts a connection from the client, it will send information about itself and the @@ -68,10 +68,16 @@ impl Command for ServerInfo { const CMD_NAME: &'static [u8] = b"INFO"; fn into_vec(self) -> Result { - Ok(format!("INFO\t{}\r\n", json::to_string(&self)?).as_bytes().into()) + let json_cmd = json::to_vec(&self)?; + let mut cmd: BytesMut = BytesMut::with_capacity(7 + json_cmd.len()); + cmd.put("INFO\t"); + cmd.put(json_cmd); + cmd.put("\r\n"); + + Ok(cmd.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { @@ -89,20 +95,20 @@ impl Command for ServerInfo { #[cfg(test)] mod tests { use super::{ServerInfo, ServerInfoBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_INFO: &'static str = "INFO\t{\"server_id\":\"test\",\"version\":\"1.3.0\",\"go\":\"go1.10.3\",\"host\":\"0.0.0.0\",\"port\":4222,\"max_payload\":4000,\"proto\":1,\"client_id\":1337}\r\n"; #[test] fn it_parses() { - let parse_res = ServerInfo::try_parse(DEFAULT_INFO.as_bytes()); + let parse_res = ServerInfo::try_parse(DEFAULT_INFO.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); - assert_eq!(&cmd.server_id, "test"); - assert_eq!(&cmd.version, "1.3.0"); + assert_eq!(cmd.server_id, "test"); + assert_eq!(cmd.version, "1.3.0"); assert_eq!(cmd.proto, Some(1u8)); - assert_eq!(&cmd.go, "go1.10.3"); - assert_eq!(&cmd.host, "0.0.0.0"); + assert_eq!(cmd.go, "go1.10.3"); + assert_eq!(cmd.host, "0.0.0.0"); assert_eq!(cmd.port, 4222u32); assert_eq!(cmd.max_payload, 4000u32); assert!(cmd.client_id.is_some()); diff --git a/src/protocol/server/message.rs b/src/protocol/server/message.rs index c1a216d..d7862ae 100644 --- a/src/protocol/server/message.rs +++ b/src/protocol/server/message.rs @@ -1,5 +1,5 @@ +use crate::protocol::{Command, CommandError}; use bytes::{BufMut, Bytes, BytesMut}; -use protocol::{Command, CommandError}; /// The MSG protocol message is used to deliver an application message to the client. #[derive(Debug, Clone, PartialEq, Builder)] @@ -44,7 +44,7 @@ impl Command for Message { Ok(bytes.freeze()) } - fn try_parse(buf: &[u8]) -> Result { + fn try_parse(buf: Bytes) -> Result { let len = buf.len(); if buf[len - 2..] != [b'\r', b'\n'] { @@ -58,29 +58,30 @@ impl Command for Message { let payload: Bytes = buf[payload_start + 2..len - 2].into(); - let whole_command = ::std::str::from_utf8(&buf[..payload_start])?; - let mut split = whole_command.split_whitespace(); + let mut split = buf[..payload_start].split(|c| *c == b' ' || *c == b'\t'); let cmd = split.next().ok_or_else(|| CommandError::CommandMalformed)?; // Check if we're still on the right command - if cmd.as_bytes() != Self::CMD_NAME { + if cmd != Self::CMD_NAME { return Err(CommandError::CommandMalformed); } - let payload_len: usize = split - .next_back() - .ok_or_else(|| CommandError::CommandMalformed)? - .parse()?; + let payload_len: usize = + std::str::from_utf8(split.next_back().ok_or_else(|| CommandError::CommandMalformed)?)?.parse()?; if payload.len() != payload_len { return Err(CommandError::CommandMalformed); } // Extract subject - let subject: String = split.next().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let subject: String = + std::str::from_utf8(split.next().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); - let sid: String = split.next().ok_or_else(|| CommandError::CommandMalformed)?.into(); + let sid: String = std::str::from_utf8(split.next().ok_or_else(|| CommandError::CommandMalformed)?)?.into(); - let reply_to: Option = split.next().map(|v| v.into()); + let reply_to: Option = match split.next() { + Some(v) => Some(std::str::from_utf8(v)?.into()), + _ => None, + }; Ok(Message { subject, @@ -113,18 +114,18 @@ impl MessageBuilder { #[cfg(test)] mod tests { use super::{Message, MessageBuilder}; - use protocol::Command; + use crate::protocol::Command; static DEFAULT_MSG: &'static str = "MSG\tFOO\tpouet\t4\r\ntoto\r\n"; #[test] fn it_parses() { - let parse_res = Message::try_parse(DEFAULT_MSG.as_bytes()); + let parse_res = Message::try_parse(DEFAULT_MSG.into()); assert!(parse_res.is_ok()); let cmd = parse_res.unwrap(); assert!(cmd.reply_to.is_none()); - assert_eq!(&cmd.subject, "FOO"); - assert_eq!(&cmd.sid, "pouet"); + assert_eq!(cmd.subject, "FOO"); + assert_eq!(cmd.sid, "pouet"); assert_eq!(cmd.payload, "toto"); } diff --git a/src/protocol/server/server_error.rs b/src/protocol/server/server_error.rs index 93b7635..b17a7b6 100644 --- a/src/protocol/server/server_error.rs +++ b/src/protocol/server/server_error.rs @@ -5,12 +5,7 @@ use std::fmt; /// /// Handling of these errors usually has to be done asynchronously. #[derive(Debug, PartialEq, Clone)] -pub struct ServerError(String); -impl From for ServerError { - fn from(s: String) -> Self { - ServerError(s) - } -} +pub struct ServerError(pub String); impl fmt::Display for ServerError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { From 5cc479693fbca0b1339e2ccc386be45139d09edd Mon Sep 17 00:00:00 2001 From: Jeremy Lempereur Date: Thu, 21 Feb 2019 17:09:34 +0100 Subject: [PATCH 2/2] Update src/protocol/client/unsub_cmd.rs Co-Authored-By: OtaK <1262712+OtaK@users.noreply.github.com> --- src/protocol/client/unsub_cmd.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/protocol/client/unsub_cmd.rs b/src/protocol/client/unsub_cmd.rs index f6ceafa..100f2eb 100644 --- a/src/protocol/client/unsub_cmd.rs +++ b/src/protocol/client/unsub_cmd.rs @@ -32,7 +32,7 @@ impl Command for UnsubCommand { const CMD_NAME: &'static [u8] = b"UNSUB"; fn into_vec(self) -> Result { - // Computes the string length of the payload_len by dividing the number par ln(10) + // Computes the string length of the payload_len by dividing the number by ln(10) let (mm_len, mm) = self.max_msgs.map_or((0, 0), |mm| { (((mm + 1) as f64 / std::f64::consts::LN_10).ceil() as usize, mm) });