-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path pendulum_data_genration.py
More file actions
53 lines (41 loc) · 1.52 KB
/
pendulum_data_genration.py
File metadata and controls
53 lines (41 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import torch
import random
import gymnasium as gym
# Default device for generated tensors: CUDA when PyTorch reports it is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def generate_data(env, n_traj, traj_len, out_dir="data", target_device=None):
    """Roll out random-action trajectories from ``env`` and save them to disk.

    Each rollout starts from ``env.reset()`` and applies ``traj_len`` actions
    drawn from ``env.action_space.sample()``. A rollout that terminates or
    truncates before its final step is discarded and retried, so every stored
    trajectory has exactly ``traj_len`` steps.

    Three tensors are written with ``torch.save`` into ``out_dir``:

    * ``pendulum_states.pt``  -- shape (n_traj, traj_len, obs_dim); entry
      ``[k, i]`` is the observation *before* step ``i`` of trajectory ``k``.
    * ``pendulum_actions.pt`` -- shape (n_traj, traj_len) when the action
      space is 1-D (the original layout), else (n_traj, traj_len, act_dim).
    * ``pendulum_rewards.pt`` -- shape (n_traj, traj_len).

    Args:
        env: Gymnasium-style environment whose ``reset`` returns
            ``(obs, info)`` and whose ``step`` returns the 5-tuple
            ``(obs, reward, terminated, truncated, info)``; both spaces must
            expose ``shape``. The env is closed once generation finishes.
        n_traj: Number of complete trajectories to keep.
        traj_len: Number of steps per trajectory.
        out_dir: Directory for the output files; created if it is missing.
        target_device: Device the saved tensors are placed on. Defaults to
            the module-level ``device``.

    Note:
        If every rollout ends before ``traj_len`` steps (e.g. the env's time
        limit is shorter than ``traj_len``), this loops forever.
    """
    import os

    if target_device is None:
        target_device = device

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.shape[0]

    # Fill host-side NumPy buffers and move to the target device once at the
    # end; element-wise writes into (possibly CUDA) tensors cost a device
    # round-trip per step.
    states = np.zeros((n_traj, traj_len, n_states), dtype=np.float32)
    actions = np.zeros((n_traj, traj_len, n_actions), dtype=np.float32)
    rewards = np.zeros((n_traj, traj_len), dtype=np.float32)

    n = 0      # trajectories kept so far
    n_tot = 0  # rollouts attempted, including discarded ones
    while n < n_traj:
        state, _ = env.reset()
        complete = True
        for i in range(traj_len):
            action = env.action_space.sample()
            state_next, reward, terminated, truncated, _ = env.step(action)
            states[n, i] = state
            actions[n, i] = action
            rewards[n, i] = reward
            state = state_next
            # Ending on the last step is fine; ending earlier leaves row ``n``
            # partially filled, so it is discarded and fully overwritten by
            # the next attempt.
            if (terminated or truncated) and i < traj_len - 1:
                complete = False
                break
        n_tot += 1
        if complete:
            n += 1
            # Report progress only when a trajectory is actually kept.
            if n % 5000 == 0:
                print("Number of trajectories generated:", n)
                print("Number of total trajectories:", n_tot)

    # Close once, after all rollouts (not after every trajectory).
    env.close()

    traj_states = torch.from_numpy(states).to(target_device)
    traj_actions = torch.from_numpy(actions)
    if n_actions == 1:
        # Preserve the original (n_traj, traj_len) layout for 1-D actions.
        traj_actions = traj_actions.squeeze(-1)
    traj_actions = traj_actions.to(target_device)
    traj_rewards = torch.from_numpy(rewards).to(target_device)

    # Create the output directory up front so a long run never dies at save.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(traj_states, os.path.join(out_dir, "pendulum_states.pt"))
    torch.save(traj_actions, os.path.join(out_dir, "pendulum_actions.pt"))
    torch.save(traj_rewards, os.path.join(out_dir, "pendulum_rewards.pt"))
def main():
    """Build Pendulum-v1 and generate 100,000 random trajectories of 200 steps."""
    pendulum_env = gym.make("Pendulum-v1")
    generate_data(pendulum_env, n_traj=100_000, traj_len=200)


# Script entry point; importing this module does not trigger generation.
if __name__ == "__main__":
    main()