From b6eb1294b0463ae0c61421ad5b260e8837edf7e2 Mon Sep 17 00:00:00 2001 From: Kaiyotech <93724202+Kaiyotech@users.noreply.github.com> Date: Thu, 15 Sep 2022 12:09:44 -0400 Subject: [PATCH] batching redis pushes on worker into pipeline --- .../redis/redis_rollout_worker.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/rocket_learn/rollout_generator/redis/redis_rollout_worker.py b/rocket_learn/rollout_generator/redis/redis_rollout_worker.py index a29bfd3..df487cc 100644 --- a/rocket_learn/rollout_generator/redis/redis_rollout_worker.py +++ b/rocket_learn/rollout_generator/redis/redis_rollout_worker.py @@ -51,7 +51,10 @@ def __init__(self, redis: Redis, name: str, match: Match, send_obs=True, scoreboard=None, pretrained_agents=None, human_agent=None, force_paging=False, auto_minimize=True, local_cache_name=None, - gamemode_weights=None,): + gamemode_weights=None, + batch_mode=False, + step_size=100_000, + ): # TODO model or config+params so workers can recreate just from redis connection? self.redis = redis self.name = name @@ -87,6 +90,11 @@ def __init__(self, redis: Redis, name: str, match: Match, self.uuid = str(uuid4()) self.redis.rpush(WORKER_IDS, self.uuid) + self.batch_mode = batch_mode + self.step_size_limit = min(step_size / 20, 25_000) + if self.batch_mode: + self.red_pipe = self.redis.pipeline() + self.step_last_send = 0 # currently doesn't rebuild, if the old is there, reuse it. if self.local_cache_name: @@ -318,7 +326,7 @@ def run(self): # Mimics Thread if not self.streamer_mode: print(post_stats) - if not self.streamer_mode: + if not self.streamer_mode and not self.batch_mode: rollout_data = encode_buffers(rollouts, return_obs=self.send_obs, return_states=self.send_gamestates, @@ -345,6 +353,26 @@ def send(): # t.start() # time.sleep(0.01) + elif not self.streamer_mode and self.batch_mode: + + rollout_data = encode_buffers(rollouts, + return_obs=self.send_obs, + return_states=self.send_gamestates, + return_rewards=True) + rollout_bytes = _serialize((rollout_data, versions, self.uuid, self.name, result, + self.send_obs, self.send_gamestates, True)) + + self.red_pipe.rpush(ROLLOUTS, rollout_bytes) + + # def send(): + if (self.total_steps_generated - self.step_last_send) > self.step_size_limit or \ + len(self.red_pipe) > 100: + n_items = self.red_pipe.execute() + if n_items[-1] >= 10000: + print("Had to limit rollouts. Learner may have have crashed, or is overloaded") + self.redis.ltrim(ROLLOUTS, -100, -1) + self.step_last_send = self.total_steps_generated + def _generate_matchup(self, n_agents, latest_version, pretrained_choice): n_old = 0 if n_agents > 1: