Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion core/services/memcached/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ all-features = true

[dependencies]
fastpool = "1.0.2"
http = { workspace = true }
opendal-core = { path = "../../core", version = "0.55.0", default-features = false }
serde = { workspace = true, features = ["derive"] }
tokio = { workspace = true, features = ["net", "io-util"] }
url = "2.5.7"

[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
84 changes: 53 additions & 31 deletions core/services/memcached/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use std::borrow::Cow;
use std::sync::Arc;
use url::Url;

use opendal_core::raw::*;
use opendal_core::*;
Expand Down Expand Up @@ -94,53 +96,73 @@ impl Builder for MemcachedBuilder {
type Config = MemcachedConfig;

fn build(self) -> Result<impl Access> {
let endpoint = self.config.endpoint.clone().ok_or_else(|| {
let endpoint_raw = self.config.endpoint.clone().ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "endpoint is empty")
.with_context("service", MEMCACHED_SCHEME)
})?;
let uri = http::Uri::try_from(&endpoint).map_err(|err| {

let url_str = if !endpoint_raw.contains("://") {
Cow::Owned(format!("tcp://{}", endpoint_raw))
} else {
Cow::Borrowed(endpoint_raw.as_str())
};

let parsed = Url::parse(&url_str).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "endpoint is invalid")
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint)
.with_context("endpoint", &endpoint_raw)
.set_source(err)
})?;

match uri.scheme_str() {
// If scheme is none, we will use tcp by default.
None => (),
Some(scheme) => {
// We only support tcp by now.
if scheme != "tcp" {
let endpoint = match parsed.scheme() {
"tcp" => {
let host = parsed.host_str().ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "tcp endpoint doesn't have host")
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint_raw)
})?;
let port = parsed.port().ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "tcp endpoint doesn't have port")
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint_raw)
})?;
Endpoint::Tcp(format!("{host}:{port}"))
}

#[cfg(unix)]
"unix" => {
let path = parsed.path();
if path.is_empty() {
return Err(Error::new(
ErrorKind::ConfigInvalid,
"endpoint is using invalid scheme",
"unix endpoint doesn't have path",
)
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint)
.with_context("scheme", scheme.to_string()));
.with_context("endpoint", &endpoint_raw));
}
Endpoint::Unix(path.to_string())
}
};

let host = if let Some(host) = uri.host() {
host.to_string()
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have host")
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint),
);
};
let port = if let Some(port) = uri.port_u16() {
port
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have port")
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint),
);
#[cfg(not(unix))]
"unix" => {
return Err(Error::new(
ErrorKind::ConfigInvalid,
"unix socket is not supported on this platform",
)
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint_raw));
}

scheme => {
return Err(Error::new(
ErrorKind::ConfigInvalid,
"endpoint is using invalid scheme, only tcp and unix are supported",
)
.with_context("service", MEMCACHED_SCHEME)
.with_context("endpoint", &endpoint_raw)
.with_context("scheme", scheme));
}
};
let endpoint = format!("{host}:{port}",);

let root = normalize_root(self.config.root.unwrap_or_else(|| "/".to_string()).as_str());

Expand Down
12 changes: 6 additions & 6 deletions core/services/memcached/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use crate::core::SocketStream;
use opendal_core::raw::*;
use opendal_core::*;
use tokio::io;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::net::TcpStream;

pub(super) mod constants {
pub const OK_STATUS: u16 = 0x0;
Expand Down Expand Up @@ -60,7 +60,7 @@ pub struct PacketHeader {
}

impl PacketHeader {
pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
pub async fn write(self, writer: &mut SocketStream) -> io::Result<()> {
writer.write_u8(self.magic).await?;
writer.write_u8(self.opcode).await?;
writer.write_u16(self.key_length).await?;
Expand All @@ -73,7 +73,7 @@ impl PacketHeader {
Ok(())
}

pub async fn read(reader: &mut TcpStream) -> Result<PacketHeader, io::Error> {
pub async fn read(reader: &mut SocketStream) -> Result<PacketHeader, io::Error> {
let header = PacketHeader {
magic: reader.read_u8().await?,
opcode: reader.read_u8().await?,
Expand All @@ -98,11 +98,11 @@ pub struct Response {

#[derive(Debug)]
pub struct Connection {
io: BufReader<TcpStream>,
io: BufReader<SocketStream>,
}

impl Connection {
pub fn new(io: TcpStream) -> Self {
pub fn new(io: SocketStream) -> Self {
Self {
io: BufReader::new(io),
}
Expand Down Expand Up @@ -246,7 +246,7 @@ impl Connection {
}
}

pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
pub async fn parse_response(reader: &mut SocketStream) -> Result<Response> {
let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;

if header.vbucket_id_or_status != constants::OK_STATUS
Expand Down
106 changes: 97 additions & 9 deletions core/services/memcached/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,104 @@ use fastpool::ObjectStatus;
use fastpool::bounded;
use opendal_core::raw::*;
use opendal_core::*;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;

use super::binary;

#[derive(Debug)]
pub enum SocketStream {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}

impl SocketStream {
pub async fn connect_tcp(addr_str: &str) -> io::Result<Self> {
let socket_addr: SocketAddr = addr_str
.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let stream = TcpStream::connect(socket_addr).await?;
Ok(SocketStream::Tcp(stream))
}

#[cfg(unix)]
pub async fn connect_unix(path: &str) -> io::Result<Self> {
let stream = UnixStream::connect(path).await?;
Ok(SocketStream::Unix(stream))
}
}

impl AsyncRead for SocketStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
SocketStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(unix)]
SocketStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl AsyncWrite for SocketStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
SocketStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(unix)]
SocketStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
SocketStream::Tcp(s) => Pin::new(s).poll_flush(cx),
#[cfg(unix)]
SocketStream::Unix(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
SocketStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(unix)]
SocketStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
}
}
}

/// Endpoint for memcached connection.
#[derive(Clone, Debug)]
pub enum Endpoint {
Tcp(String), // host:port
#[cfg(unix)]
Unix(String), // socket path
}

/// A connection manager for `memcache_async::ascii::Protocol`.
#[derive(Clone)]
struct MemcacheConnectionManager {
address: String,
endpoint: Endpoint,
username: Option<String>,
password: Option<String>,
}

impl MemcacheConnectionManager {
fn new(address: &str, username: Option<String>, password: Option<String>) -> Self {
fn new(endpoint: Endpoint, username: Option<String>, password: Option<String>) -> Self {
Self {
address: address.to_string(),
endpoint,
username,
password,
}
Expand All @@ -48,11 +130,17 @@ impl ManageObject for MemcacheConnectionManager {
type Object = binary::Connection;
type Error = Error;

/// TODO: Implement unix stream support.
async fn create(&self) -> Result<Self::Object, Self::Error> {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;
let conn = match &self.endpoint {
Endpoint::Tcp(addr) => SocketStream::connect_tcp(addr)
.await
.map_err(new_std_io_error)?,
#[cfg(unix)]
Endpoint::Unix(path) => SocketStream::connect_unix(path)
.await
.map_err(new_std_io_error)?,
};

let mut conn = binary::Connection::new(conn);

if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) {
Expand Down Expand Up @@ -81,15 +169,15 @@ pub struct MemcachedCore {

impl MemcachedCore {
pub fn new(
endpoint: String,
endpoint: Endpoint,
username: Option<String>,
password: Option<String>,
default_ttl: Option<Duration>,
connection_pool_max_size: Option<usize>,
) -> Self {
let conn = bounded::Pool::new(
bounded::PoolConfig::new(connection_pool_max_size.unwrap_or(10)),
MemcacheConnectionManager::new(endpoint.as_str(), username, password),
MemcacheConnectionManager::new(endpoint, username, password),
);

Self { default_ttl, conn }
Expand Down
Loading