train_hopper_adr.py
35 lines (28 loc) · 1.27 KB
import os

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

from env.custom_hopper_adr import *  # registers the custom Hopper environments
from automatic_domain_randomization import ADR, ADRCallback
from util import plot_train_results

if __name__ == "__main__":
    env_id = "CustomHopper-source-adr-v0"
    env = gym.make(env_id)

    log_dir = f"./tmp/gym/train/{env_id}"
    os.makedirs(log_dir, exist_ok=True)

    # Initial ADR parameters
    min_max_bounds = [(1, 10) for _ in env.get_parameters()]  # outer limits for each randomized mass
    masses_bounds = [(0.95 * mass, 1.05 * mass) for mass in env.get_parameters()]  # initial randomization ranges
    thresholds = (550, 1150)  # performance thresholds used to shrink/expand the ranges
    delta = 0.1               # update step size
    m = 10                    # performance buffer size
    fixed_torso_mass = env.get_parameters()[0]  # torso mass stays fixed during randomization

    # Monitored training environment, plus a separate vectorized environment
    # used by ADR to evaluate performance at the range boundaries
    env = Monitor(env, log_dir)
    adr_env = DummyVecEnv([lambda: gym.make(env_id)])

    adr = ADR(masses_bounds, thresholds, delta, m, min_max_bounds, adr_env, fixed_torso_mass)
    adr_callback = ADRCallback(adr, env, "entropy_log_hopper_adr.csv")

    model = PPO("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=1_000_000, callback=adr_callback, progress_bar=True)
    model.save(f"models/PPO_ADR_{env_id}")

    plot_train_results(log_dir, env_id)
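After training, the saved policy can be loaded back and evaluated. A minimal sketch, assuming the same environment registration as the training script; it uses SB3's `evaluate_policy`, and the number of evaluation episodes is an arbitrary choice:

```python
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

from env.custom_hopper_adr import *  # registers the custom Hopper environments

env_id = "CustomHopper-source-adr-v0"
eval_env = gym.make(env_id)

# Load the model saved by train_hopper_adr.py and report mean episodic return
model = PPO.load(f"models/PPO_ADR_{env_id}")
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
```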