class EpisodeWrapper(Wrapper):
  """Maintains episode step count and sets done at episode end."""

  def __init__(self, env: Env, episode_length: int, action_repeat: int):
    super().__init__(env)
    self.episode_length = episode_length
    self.action_repeat = action_repeat

  def reset(self, rng: jax.Array) -> State:
    state = self.env.reset(rng)
    state.info['steps'] = jp.zeros(rng.shape[:-1])
    state.info['truncation'] = jp.zeros(rng.shape[:-1])
    # Keep separate record of episode done as state.info['done'] can be erased
    # by AutoResetWrapper
    state.info['episode_done'] = jp.zeros(rng.shape[:-1])
    episode_metrics = dict()
    episode_metrics['sum_reward'] = jp.zeros(rng.shape[:-1])
    episode_metrics['length'] = jp.zeros(rng.shape[:-1])
    for metric_name in state.metrics.keys():
      episode_metrics[metric_name] = jp.zeros(rng.shape[:-1])
    state.info['episode_metrics'] = episode_metrics
    return state

  def step(self, state: State, action: jax.Array) -> State:
    def f(state, _):
      nstate = self.env.step(state, action)
      return nstate, nstate.reward

    state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
    state = state.replace(reward=jp.sum(rewards, axis=0))
    steps = state.info['steps'] + self.action_repeat
    one = jp.ones_like(state.done)
    zero = jp.zeros_like(state.done)
    episode_length = jp.array(self.episode_length, dtype=jp.int32)
    done = jp.where(steps >= episode_length, one, state.done)
    state.info['truncation'] = jp.where(
        steps >= episode_length, 1 - state.done, zero
    )
    state.info['steps'] = steps
    # Aggregate state metrics into episode metrics
    prev_done = state.info['episode_done']
    state.info['episode_metrics']['sum_reward'] += jp.sum(rewards, axis=0)
    state.info['episode_metrics']['sum_reward'] *= (1 - prev_done)  # <----------------- bug
    state.info['episode_metrics']['length'] += self.action_repeat
    state.info['episode_metrics']['length'] *= (1 - prev_done)
    for metric_name in state.metrics.keys():
      if metric_name != 'reward':
        state.info['episode_metrics'][metric_name] += state.metrics[metric_name]
        state.info['episode_metrics'][metric_name] *= (1 - prev_done)
    state.info['episode_done'] = done
    return state.replace(done=done)
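A minimal sketch of the failure mode at the marked line, using plain jax.numpy with hypothetical values (not taken from the wrapper itself): when prev_done is 1 because the environment was auto-reset on this step, the add-then-zero ordering wipes out the first reward of the new episode along with the finished episode's total.

```python
from jax import numpy as jp

prev_done = jp.array(1.0)   # previous step ended an episode (hypothetical value)
sum_reward = jp.array(5.0)  # accumulated reward of the episode that just ended
new_reward = jp.array(0.7)  # reward of the first step of the new episode

# Current ordering: accumulate first, then zero out -> the new reward is erased.
buggy = (sum_reward + new_reward) * (1 - prev_done)   # 0.0

# Reordered: zero out the old total first, then accumulate -> the new reward survives.
fixed = sum_reward * (1 - prev_done) + new_reward     # 0.7

print(buggy, fixed)
```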
When an episode ends and a new episode starts, the reward of the first step of the new episode is erased because of prev_done: the new reward is added to the accumulator before the accumulator is multiplied by (1 - prev_done). One way to solve this is to interchange the two lines so that the zeroing happens before the accumulation:

state.info['episode_metrics']['sum_reward'] *= (1 - prev_done)  # zero out the finished episode first
state.info['episode_metrics']['sum_reward'] += jp.sum(rewards, axis=0)  # then accumulate this step

The same applies to all the other metrics, such as length. A sketch of the full block is given below.
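For concreteness, here is a sketch of the aggregation block with that reordering applied to every accumulator. It reuses the names from the wrapper above and is illustrative only, not necessarily the exact diff in the PR.

```python
prev_done = state.info['episode_done']
keep = 1 - prev_done  # 0 for envs that just finished an episode, 1 otherwise

# Zero the accumulators of finished episodes *before* adding this step's values,
# so the first step of a fresh episode is not erased.
state.info['episode_metrics']['sum_reward'] = (
    state.info['episode_metrics']['sum_reward'] * keep + jp.sum(rewards, axis=0)
)
state.info['episode_metrics']['length'] = (
    state.info['episode_metrics']['length'] * keep + self.action_repeat
)
for metric_name in state.metrics.keys():
  if metric_name != 'reward':
    state.info['episode_metrics'][metric_name] = (
        state.info['episode_metrics'][metric_name] * keep
        + state.metrics[metric_name]
    )
```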
I have created a PR; please accept it.