Skip to content
Open

test #469

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions src/acl.c
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,17 @@ void ACLCopyUser(user *dst, user *src) {
}
}

/* Set the user that client 'c' is authenticated as, performing any necessary
* bookkeeping for the switch. In particular, any pending BCAST tracking
* invalidations are flushed under the client's current ACL identity before
* c->user changes, so they are not re-filtered by the new user's key
* permissions in beforeSleep. */
void clientSetUser(client *c, user *new_user) {
if (c->user != new_user)
trackingBroadcastFlushClientPrefixes(c);
c->user = new_user;
}

/* Given a command ID, this function set by reference 'word' and 'bit'
* so that user->allowed_commands[word] will address the right word
* where the corresponding bit for the provided ID is stored, and
Expand Down Expand Up @@ -1497,7 +1508,7 @@ void addAuthErrReply(client *c, robj *err) {
int checkPasswordBasedAuth(client *c, robj *username, robj *password) {
if (ACLCheckUserCredentials(username,password) == C_OK) {
c->authenticated = 1;
c->user = ACLGetUserByName(username->ptr,sdslen(username->ptr));
clientSetUser(c, ACLGetUserByName(username->ptr,sdslen(username->ptr)));
moduleNotifyUserChanged(c);
return AUTH_OK;
} else {
Expand Down Expand Up @@ -2147,6 +2158,12 @@ sds ACLStringSetUser(user *u, sds username, sds *argv, int argc) {
* disconnected if (some of) their channel permissions were revoked. */
if (u) {
ACLKillPubsubClientsIfNeeded(tempu, u);
/* Deliver pending BCAST tracking invalidations under the user's
* current permissions before overwriting them in place below.
* Otherwise beforeSleep would re-filter the already accumulated keys
* by the new (possibly stricter) permissions and drop invalidations
* for keys the client could previously read. */
trackingBroadcastInvalidationMessages(u);
}

/* Overwrite the user with the temporary user we modified above. */
Expand Down Expand Up @@ -2439,6 +2456,14 @@ sds ACLLoadFromFile(const char *filename) {

/* Check if we found errors and react accordingly. */
if (sdslen(errors) == 0) {
/* Deliver pending BCAST tracking invalidations under the pre-load ACL
* identities before mutating any user. In particular DefaultUser is
* overwritten in place below, which would otherwise cause its pending
* invalidations to be re-filtered by the new permissions in
* beforeSleep. A whole-table flush is appropriate here since the load
* may change many users at once. */
trackingBroadcastInvalidationMessages(NULL);

/* The default user pointer is referenced in different places: instead
* of replacing such occurrences it is much simpler to copy the new
* default user configuration in the old one. */
Expand Down Expand Up @@ -2481,7 +2506,7 @@ sds ACLLoadFromFile(const char *filename) {
deauthenticateAndCloseClient(c);
continue;
}
c->user = new;
clientSetUser(c, new);
}

if (user_channels)
Expand Down Expand Up @@ -3241,7 +3266,7 @@ static void internalAuth(client *c) {
c->authenticated = 1;
/* Set the user to the unrestricted user, if it is not already set (default). */
if (c->user != NULL) {
c->user = NULL;
clientSetUser(c, NULL);
moduleNotifyUserChanged(c);
}
addReply(c, shared.ok);
Expand Down
2 changes: 1 addition & 1 deletion src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -10809,8 +10809,8 @@ static int authenticateClientWithUser(RedisModuleCtx *ctx, user *user, RedisModu

moduleNotifyUserChanged(ctx->client);

ctx->client->user = user;
ctx->client->authenticated = 1;
clientSetUser(ctx->client, user);

if (clientHasModuleAuthInProgress(ctx->client)) {
ctx->client->flags |= CLIENT_MODULE_AUTH_HAS_RESULT;
Expand Down
6 changes: 4 additions & 2 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void linkClient(client *c) {
static void clientSetDefaultAuth(client *c) {
/* If the default user does not require authentication, the user is
* directly authenticated. */
c->user = DefaultUser;
clientSetUser(c, DefaultUser);
c->authenticated = (c->user->flags & USER_FLAG_NOPASS) &&
!(c->user->flags & USER_FLAG_DISABLED);
}
Expand Down Expand Up @@ -193,6 +193,7 @@ client *createClient(connection *conn) {
c->ctime = c->lastinteraction = server.unixtime;
c->io_lastinteraction = 0;
c->duration = 0;
c->user = DefaultUser; /* Set a safe default value: clientSetDefaultAuth reads c->user. */
clientSetDefaultAuth(c);
c->replstate = REPL_STATE_NONE;
c->repl_start_cmd_stream_on_ack = 0;
Expand Down Expand Up @@ -1614,8 +1615,8 @@ void clientAcceptHandler(connection *conn) {
if (username != NULL) {
user *u = ACLGetUserByName(username, sdslen(username));
if (u && !(u->flags & USER_FLAG_DISABLED)) {
c->user = u;
c->authenticated = 1;
clientSetUser(c, u);
moduleNotifyUserChanged(c);
serverLog(LL_VERBOSE, "TLS: Auto-authenticated client as %s",
server.hide_user_data_from_log ? "*redacted*" : u->name);
Expand Down Expand Up @@ -2073,6 +2074,7 @@ void clearClientConnectionState(client *c) {
}

void deauthenticateAndCloseClient(client *c) {
disableTracking(c);
c->user = DefaultUser;
c->authenticated = 0;
/* We will write replies to this client later, so we can't
Expand Down
2 changes: 1 addition & 1 deletion src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -2008,7 +2008,7 @@ void beforeSleep(struct aeEventLoop *eventLoop) {

/* Send the invalidation messages to clients participating to the
* client side caching protocol in broadcasting (BCAST) mode. */
trackingBroadcastInvalidationMessages();
trackingBroadcastInvalidationMessages(NULL);

/* Record time consumption of AOF writing. */
monotime aof_start_time = getMonotonicUs();
Expand Down
4 changes: 3 additions & 1 deletion src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -3360,7 +3360,9 @@ void trackingLimitUsedSlots(void);
uint64_t trackingGetTotalItems(void);
uint64_t trackingGetTotalKeys(void);
uint64_t trackingGetTotalPrefixes(void);
void trackingBroadcastInvalidationMessages(void);
void trackingBroadcastInvalidationMessages(user *u);
void trackingBroadcastFlushClientPrefixes(client *c);
void clientSetUser(client *c, user *new_user);
int checkPrefixCollisionsOrReply(client *c, robj **prefix, size_t numprefix);

/* List data type */
Expand Down
198 changes: 143 additions & 55 deletions src/tracking.c
Original file line number Diff line number Diff line change
Expand Up @@ -551,32 +551,32 @@ void trackingLimitUsedSlots(void) {
timeout_counter++;
}

/* Generate Redis protocol for an array containing all the key names
* in the 'keys' radix tree. If the client is not NULL, the list will not
* include keys that were modified the last time by this client, in order
* to implement the NOLOOP option.
/* Build the RESP array of invalidated key names in 'keys', filtered by:
* - ACL key permissions of user 'u' (NULL means all keys are permitted).
* - NOLOOP: if 'noloop_client' is non-NULL, keys last modified by
* that client are excluded.
*
* If the resulting array would be empty, NULL is returned instead. */
sds trackingBuildBroadcastReply(client *c, rax *keys) {
sds trackingBuildBroadcastReply(user *u, client *noloop_client, rax *keys) {
raxIterator ri;
uint64_t count;
uint64_t count = 0;

if (c == NULL) {
count = raxSize(keys);
} else {
count = 0;
raxStart(&ri,keys);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
if (ri.data != c) count++;
}
raxStop(&ri);

if (count == 0) return NULL;
raxStart(&ri,keys);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
if (noloop_client && ri.data == noloop_client)
continue;
if (ACLUserCheckKeyPerm(u, (char *)ri.key, ri.key_len,
CMD_KEY_ACCESS) != ACL_OK)
continue;
count++;
}
raxStop(&ri);

if (count == 0) return NULL;

/* Create the array reply with the list of keys once, then send
* it to all the clients subscribed to this prefix. */
* it to the receiving client. */
char buf[32];
size_t len = ll2string(buf,sizeof(buf),count);
sds proto = sdsempty();
Expand All @@ -587,7 +587,11 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) {
raxStart(&ri,keys);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
if (c && ri.data == c) continue;
if (noloop_client && ri.data == noloop_client)
continue;
if (ACLUserCheckKeyPerm(u, (char *)ri.key, ri.key_len,
CMD_KEY_ACCESS) != ACL_OK)
continue;
len = ll2string(buf,sizeof(buf),ri.key_len);
proto = sdscatlen(proto,"$",1);
proto = sdscatlen(proto,buf,len);
Expand All @@ -599,11 +603,125 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) {
return proto;
}

/* Send the pending BCAST invalidation messages accumulated in a single
* prefix's bcastState to every client subscribed to that prefix, then reset
* bs->keys so only keys accumulated from now on are tracked.
*
* For non-NOLOOP clients the invalidation proto is cached per distinct
* ACL user pointer so that ACLUserCheckKeyPerm is called O(U*K) times
* instead of O(C*K) (U = distinct users, C = clients, K = keys). */
static void trackingBcastInvalidationsForPrefix(bcastState *bs) {
if (raxSize(bs->keys) == 0) return;

raxIterator ri;

/* Per-user proto cache. Key: user * pointer (identity),
* value: sds proto (may be NULL for users whose keys are all
* filtered out by ACL). */
dictType dt = { .hashFunction = dictPtrHash };
dict *user_cache = dictCreate(&dt);

/* Send this array of keys to every client in the list. */
raxStart(&ri,bs->clients);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
client *c;
memcpy(&c,ri.key,sizeof(c));

if (c->flags & CLIENT_TRACKING_NOLOOP) {
sds proto = trackingBuildBroadcastReply(c->user, c, bs->keys);
if (proto) {
sendTrackingMessage(c,proto,sdslen(proto),1);
sdsfree(proto);
}
} else {
dictEntry *existing;
dictEntry *de = dictAddRaw(user_cache, c->user, &existing);
if (de != NULL) {
sds proto = trackingBuildBroadcastReply(c->user, NULL,
bs->keys);
dictSetVal(user_cache, de, proto);
} else {
de = existing;
}
void *cached = dictGetVal(de);
if (cached)
sendTrackingMessage(c,(char*)cached,sdslen((sds)cached),1);
}
}
raxStop(&ri);

/* Free all cached protos. */
dictIterator *cache_di = dictGetIterator(user_cache);
dictEntry *de;
while ((de = dictNext(cache_di)) != NULL) {
sds proto = dictGetVal(de);
if (proto) sdsfree(proto);
}
dictReleaseIterator(cache_di);
dictRelease(user_cache);

/* Clean up: we can remove everything from this state, because we
* want to only track the new keys that will be accumulated starting
* from now. */
raxFree(bs->keys);
bs->keys = raxNew();
}

/* Return 1 if at least one client subscribed to 'bs' is authenticated as
* user 'u', 0 otherwise. */
static int bcastStateHasUser(bcastState *bs, user *u) {
raxIterator ri;
raxStart(&ri,bs->clients);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
client *c;
memcpy(&c,ri.key,sizeof(c));
if (c->user == u) {
raxStop(&ri);
return 1;
}
}
raxStop(&ri);
return 0;
}

/* Flush the pending BCAST invalidation messages for every prefix that client
* 'c' subscribes to, so the keys accumulated so far are delivered under c's
* CURRENT ACL identity.
*
* This must be called BEFORE c->user is changed (e.g. on re-AUTH). Otherwise
* beforeSleep would re-filter the already-accumulated keys by the new
* (possibly stricter) permissions and drop invalidations for keys the client
* could previously read. No-op if 'c' is not a BCAST tracking client. */
void trackingBroadcastFlushClientPrefixes(client *c) {
if (!(c->flags & CLIENT_TRACKING_BCAST)) return;
if (c->client_tracking_prefixes == NULL) return;
if (TrackingTable == NULL || !server.tracking_clients) return;

raxIterator ri;
raxStart(&ri,c->client_tracking_prefixes);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
void *result;
int found = raxFind(PrefixTable,ri.key,ri.key_len,&result);
serverAssert(found);
trackingBcastInvalidationsForPrefix(result);
}
raxStop(&ri);
}

/* This function will run the prefixes of clients in BCAST mode and
* keys that were modified about each prefix, and will send the
* notifications to each client in each prefix. */
void trackingBroadcastInvalidationMessages(void) {
raxIterator ri, ri2;
* notifications to each client in each prefix.
*
* If 'u' is non-NULL, only prefixes that have at least one client
* authenticated as 'u' are flushed. This is used to deliver pending
* invalidations under the old identity before an in-place ACL change to 'u'
* would otherwise cause beforeSleep to re-filter them by the new permissions.
* Passing NULL flushes every prefix. */
void trackingBroadcastInvalidationMessages(user *u) {
raxIterator ri;

/* Return ASAP if there is nothing to do here. */
if (TrackingTable == NULL || !server.tracking_clients) return;
Expand All @@ -614,38 +732,8 @@ void trackingBroadcastInvalidationMessages(void) {
/* For each prefix... */
while(raxNext(&ri)) {
bcastState *bs = ri.data;

if (raxSize(bs->keys)) {
/* Generate the common protocol for all the clients that are
* not using the NOLOOP option. */
sds proto = trackingBuildBroadcastReply(NULL,bs->keys);

/* Send this array of keys to every client in the list. */
raxStart(&ri2,bs->clients);
raxSeek(&ri2,"^",NULL,0);
while(raxNext(&ri2)) {
client *c;
memcpy(&c,ri2.key,sizeof(c));
if (c->flags & CLIENT_TRACKING_NOLOOP) {
/* This client may have certain keys excluded. */
sds adhoc = trackingBuildBroadcastReply(c,bs->keys);
if (adhoc) {
sendTrackingMessage(c,adhoc,sdslen(adhoc),1);
sdsfree(adhoc);
}
} else {
sendTrackingMessage(c,proto,sdslen(proto),1);
}
}
raxStop(&ri2);

/* Clean up: we can remove everything from this state, because we
* want to only track the new keys that will be accumulated starting
* from now. */
sdsfree(proto);
}
raxFree(bs->keys);
bs->keys = raxNew();
if (u == NULL || bcastStateHasUser(bs, u))
trackingBcastInvalidationsForPrefix(bs);
}
raxStop(&ri);
}
Expand Down
Loading
Loading