-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathreplaybuffer.py
More file actions
97 lines (80 loc) · 3.43 KB
/
Copy pathreplaybuffer.py
File metadata and controls
97 lines (80 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''
With inspiration from
https://github.com/facebookresearch/drqv2/pull/17/files
https://github.com/ikostrikov/jaxrl/blob/haiku/replay_buffer.py
'''
import numpy as np
# from jax import numpy as np
class ReplayBuffer(object):
    """Fixed-capacity circular replay buffer with n-step return sampling.

    Transitions are stored in preallocated NumPy arrays; once ``max_size``
    transitions have been added, new ones overwrite the oldest.

    Args:
        obs_dim: Length of each (flat) observation vector.
        act_dim: Length of each (flat) action vector.
        max_size: Capacity of the buffer, in transitions.
        discount: Per-step discount factor used for the n-step return.
        nstep: Number of consecutive transitions summed by ``sample``.

    Raises:
        ValueError: if ``nstep`` is smaller than 1.
    """

    def __init__(self, obs_dim, act_dim, max_size=int(1e6), discount=.99, nstep=1):
        if nstep < 1:
            raise ValueError(f"nstep must be >= 1, got {nstep}")
        self.index = 0    # physical slot the next transition is written to
        self.size = 0     # number of valid transitions currently stored
        self.length = 0   # total transitions ever added (not capped at max_size)
        self.max_size = max_size
        self.discounts = np.power(discount, np.arange(nstep))  # [1, g, g**2, ...]
        self.nstep = nstep
        self.obs = np.zeros((max_size, obs_dim))
        self.act = np.zeros((max_size, act_dim))
        self.nobs = np.zeros((max_size, obs_dim))
        # Sized like the other arrays: add() only ever writes slots [0, max_size).
        self.rew = np.zeros((max_size, 1))
        self.done = np.zeros((max_size, 1))  # 0 while the episode continues, 1 when done

    def add(self, obs, act, rew, next_obs, done):
        """Store one transition, overwriting the oldest one when the buffer is full.

        Args:
            obs: Observation the action was taken in.
            act: Action taken.
            rew: Scalar reward received.
            next_obs: Observation reached after the action.
            done: Truthy if the episode terminated on this transition.
        """
        slot = self.index
        self.obs[slot] = obs
        self.act[slot] = act
        self.rew[slot] = rew
        self.nobs[slot] = next_obs
        self.done[slot] = done
        self.index = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
        self.length += 1

    def sample(self, batch_size):
        """Sample ``batch_size`` n-step transitions uniformly at random.

        Each sample starts at a stored transition t and covers the window
        t, t+1, ..., t+nstep-1 in insertion order, handling wraparound of the
        circular storage so a window never straddles the write head. The
        window is cut short at the first transition flagged ``done``, so
        rewards from the following episode are never added in.

        NOTE(review): episodes that end without ``done=1`` (e.g. time-limit
        truncation) cannot be detected here; such windows may span two
        episodes.

        Returns:
            dict with keys
              'obs'      (batch_size, obs_dim): observation at t.
              'act'      (batch_size, act_dim): action at t.
              'rew'      (batch_size, 1): sum_k discount**k * r_{t+k} over the
                         (possibly truncated) window.
              'next_obs' (batch_size, obs_dim): next observation of the last
                         transition in the window.
              'done'     (batch_size, 1): 1.0 if the window reached a terminal
                         transition, else 0.0.

            When 'done' is 0 the window spans all ``nstep`` transitions, so a
            bootstrap term should be discounted by discount**nstep. When
            'done' is 1 the bootstrap is zeroed and the exponent is irrelevant.

        Raises:
            ValueError: if fewer than ``nstep`` transitions are stored.
        """
        if self.size < self.nstep:
            raise ValueError(
                f"need at least nstep={self.nstep} stored transitions to sample, "
                f"have {self.size}")

        # Physical slot of the oldest stored transition (0 until the buffer wraps).
        oldest = (self.index - self.size) % self.max_size
        # Time-ordered start offsets: the window [start, start+nstep) always
        # lies within the stored history, so it cannot cross the write head.
        starts = np.random.randint(0, self.size - self.nstep + 1, size=batch_size)
        idx = (oldest + starts[:, None] + np.arange(self.nstep)) % self.max_size
        rows = np.arange(batch_size)

        terminal = self.done[idx, 0] > 0  # (batch_size, nstep) bool
        # Step k is "alive" when no terminal transition occurred strictly before it.
        alive = (np.cumsum(terminal, axis=1) - terminal) == 0
        rew = (self.rew[idx, 0] * self.discounts * alive).sum(axis=1, keepdims=True)

        # The last alive step is the first terminal step, or nstep-1 if there is none.
        last_step = alive.sum(axis=1) - 1
        last = idx[rows, last_step]
        done = terminal[rows, last_step].astype(np.float64)[:, None]

        first = idx[:, 0]
        return {
            'obs': self.obs[first],
            'act': self.act[first],
            'rew': rew,
            'next_obs': self.nobs[last],
            'done': done,
        }