diff --git a/environment/agent.py b/environment/agent.py index 013929f..2848d66 100644 --- a/environment/agent.py +++ b/environment/agent.py @@ -1037,12 +1037,18 @@ def train(agent: Agent, if train_logging == TrainLogging.PLOT: plot_results(log_dir) + # Also generate detailed analysis + try: + from user.plot_learning_curve import plot_learning_curve + plot_learning_curve(log_dir) + except Exception as e: + print(f"Note: Could not generate detailed analysis: {e}") ## Run Human vs AI match function import pygame from pygame.locals import QUIT -def run_real_time_match(agent_1: UserInputAgent, agent_2: Agent, max_timesteps=30*90, resolution=CameraResolution.LOW): +def run_real_time_match(agent_1: Agent, agent_2: Agent, max_timesteps=30*900, resolution=CameraResolution.HIGH): pygame.init() pygame.mixer.init() @@ -1084,7 +1090,7 @@ def run_real_time_match(agent_1: UserInputAgent, agent_2: Agent, max_timesteps=3 # platform1 = env.objects["platform1"] #mohamed #stage2 = env.objects["stage2"] background_image = pygame.image.load('environment/assets/map/bg.jpg').convert() - while running and timestep < max_timesteps: + while running and timestep < 2147000000: # Pygame event to handle real-time user input @@ -1134,6 +1140,8 @@ def run_real_time_match(agent_1: UserInputAgent, agent_2: Agent, max_timesteps=3 result = Result.LOSS else: result = Result.DRAW + + print(result) match_stats = MatchStats( match_time=timestep / 30.0, @@ -1145,4 +1153,4 @@ def run_real_time_match(agent_1: UserInputAgent, agent_2: Agent, max_timesteps=3 # Close environment env.close() - return match_stats \ No newline at end of file + return match_stats diff --git a/environment/environment.py b/environment/environment.py index 3bc75f7..3bb1569 100644 --- a/environment/environment.py +++ b/environment/environment.py @@ -876,6 +876,7 @@ def __init__(self, mode: RenderMode=RenderMode.RGB_ARRAY, resolution: CameraReso self.observation_space = self.get_observation_space() self.camera = Camera() + self.camera.zoom = 0.7 # Action Space # WASD @@ -4952,4 +4953,4 @@ def render(self, canvas: pygame.Surface, camera: Camera) -> None: screen_pos = camera.gtp(self.position) screen_pos = (0,0) #canvas.blit(self.frames[self.current_frame_index], screen_pos) - self.draw_image(canvas, self.frames[self.current_frame_index], self.position, 2, camera) \ No newline at end of file + self.draw_image(canvas, self.frames[self.current_frame_index], self.position, 2, camera) diff --git a/user/my_agent.py b/user/my_agent.py index cd6cfce..104afc1 100644 --- a/user/my_agent.py +++ b/user/my_agent.py @@ -31,47 +31,86 @@ class SubmittedAgent(Agent): Input the **file_path** to your agent here for submission! ''' def __init__( - self, - file_path: Optional[str] = None, + self, + *args, + **kwargs ): - super().__init__(file_path) + super().__init__(*args, **kwargs) + self.time = 0 + self.prev_pos = None + self.down = False + self.recover = False - # To run a TTNN model, you must maintain a pointer to the device and can be done by - # uncommmenting the line below to use the device pointer - # self.mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1,1)) + def predict(self, obs): + self.time += 1 + pos = self.obs_helper.get_section(obs, 'player_pos') + opp_pos = self.obs_helper.get_section(obs, 'opponent_pos') + opp_KO = self.obs_helper.get_section(obs, 'opponent_state') in [5, 11] + action = self.act_helper.zeros() + facing = self.obs_helper.get_section(obs, 'player_facing') - def _initialize(self) -> None: - if self.file_path is None: - self.model = PPO("MlpPolicy", self.env, verbose=0) - del self.env - else: - self.model = PPO.load(self.file_path) + opp_grounded = self.obs_helper.get_section(obs, 'opponent_grounded') + opp_state = self.obs_helper.get_section(obs, 'opponent_state') + opp_move_type = self.obs_helper.get_section(obs, 'opponent_move_type') - # To run the sample TTNN model during inference, you can uncomment the 5 lines below: - # This assumes that your self.model.policy has the MLPPolicy architecture defined in `train_agent.py` or `my_agent_tt.py` - # mlp_state_dict = self.model.policy.features_extractor.model.state_dict() - # self.tt_model = TTMLPPolicy(mlp_state_dict, self.mesh_device) - # self.model.policy.features_extractor.model = self.tt_model - # self.model.policy.vf_features_extractor.model = self.tt_model - # self.model.policy.pi_features_extractor.model = self.tt_model + is_opponent_spamming = opp_grounded == 1 and opp_state == 8 and opp_move_type > 0 - def _gdown(self) -> str: - data_path = "rl-model.zip" - if not os.path.isfile(data_path): - print(f"Downloading {data_path}...") - # Place a link to your PUBLIC model data here. This is where we will download it from on the tournament server. - url = "https://drive.google.com/file/d/1JIokiBOrOClh8piclbMlpEEs6mj3H1HJ/view?usp=sharing" - gdown.download(url, output=data_path, fuzzy=True) - return data_path + spawners = self.env.get_spawner_info() - def predict(self, obs): - action, _ = self.model.predict(obs) - return action + # pick up a weapon if near + ''' + if self.obs_helper.get_section(obs, 'player_weapon_type') == 0: + for w in spawners: + if euclid(pos, w[1]) < 3: + action = self.act_helper.press_keys(['h'], action) + ''' + + # emote for fun + if self.time == 10 or self.obs_helper.get_section(obs, 'opponent_stocks') == 0: + action = self.act_helper.press_keys(['g'], action) + return action + + if self.prev_pos is not None: + self.down = (pos[1] - self.prev_pos[1]) > 0 + self.prev_pos = pos - def save(self, file_path: str) -> None: - self.model.save(file_path) + self.recover = False + if pos[0] < -4.8: + action = self.act_helper.press_keys(['d'], action) + self.recover = True + elif pos[0] > -4.2 and pos[0] < 0: + action = self.act_helper.press_keys(['a'], action) + self.recover = True + elif pos[0] > 0 and pos[0] < 4.2: + action = self.act_helper.press_keys(['d'], action) + self.recover = True + elif pos[0] > 4.8: + action = self.act_helper.press_keys(['a'], action) + self.recover = True + + # Jump if falling + if pos[1] > -5 and (self.down or (self.obs_helper.get_section(obs, 'player_grounded') == 1) and not is_opponent_spamming): + if self.time % 10 == 0: + action = self.act_helper.press_keys(['space'], action) + if self.recover and self.obs_helper.get_section(obs, 'player_grounded') == 0 and self.obs_helper.get_section(obs, 'player_jumps_left') == 0 and self.obs_helper.get_section(obs, 'player_recoveries_left') == 1 and self.time % 2 == 0: + action = self.act_helper.press_keys(['k'], action) + + + if not self.recover: + if opp_pos[0] > pos[0]: + action = self.act_helper.press_keys(['d'], action) + elif opp_pos[0] < pos[0]: + action = self.act_helper.press_keys(['a'], action) + + + # Attack if near + if not self.recover and abs(pos[0] - opp_pos[0]) < 0.5 and pos[1] < opp_pos[1]: + action = self.act_helper.press_keys(['s'], action) + action = self.act_helper.press_keys(['k'], action) + elif not self.recover and euclid(pos, opp_pos) < 4: + action = self.act_helper.press_keys(['j'], action) + + return action - # If modifying the number of models (or training in general), modify this - def learn(self, env, total_timesteps, log_interval: int = 4): - self.model.set_env(env) - self.model.learn(total_timesteps=total_timesteps, log_interval=log_interval) \ No newline at end of file +def euclid (a, b): + return (a[0] - b[0])**2 + (a[1] - b[1])**2 diff --git a/user/opp_agent.py b/user/opp_agent.py new file mode 100644 index 0000000..d5ad42a --- /dev/null +++ b/user/opp_agent.py @@ -0,0 +1,453 @@ +# # SUBMISSION: Agent +# This will be the Agent class we run in the 1v1. We've started you off with a functioning RL agent (`SB3Agent(Agent)`) and if-statement agent (`BasedAgent(Agent)`). Feel free to copy either to `SubmittedAgent(Agent)` then begin modifying. +# +# Requirements: +# - Your submission **MUST** be of type `SubmittedAgent(Agent)` +# - Any instantiated classes **MUST** be defined within and below this code block. +# +# Remember, your agent can be either machine learning, OR if-statement based. I've seen many successful agents arising purely from if-statements - give them a shot as well, if ML is too complicated at first!! +# +# Also PLEASE ask us questions in the Discord server if any of the API is confusing. We'd be more than happy to clarify and get the team on the right track. +# Requirements: +# - **DO NOT** import any modules beyond the following code block. They will not be parsed and may cause your submission to fail validation. +# - Only write imports that have not been used above this code block +# - Only write imports that are from libraries listed here +# We're using PPO by default, but feel free to experiment with other Stable-Baselines 3 algorithms! + +import os +import gdown +from typing import Optional +import numpy as np +import random +import math +from environment.agent import Agent +from stable_baselines3 import PPO, A2C # Sample RL Algo imports +from sb3_contrib import RecurrentPPO # Importing an LSTM + +# To run the sample TTNN model, you can uncomment the 2 lines below: +# import ttnn +# from user.my_agent_tt import TTMLPPolicy + + +class SubmittedAgent(Agent): + ''' + Input the **file_path** to your agent here for submission! + ''' + def __init__( + self, + file_path: Optional[str] = None, + ): + super().__init__(file_path) + + # To run a TTNN model, you must maintain a pointer to the device and can be done by + # uncommmenting the line below to use the device pointer + # self.mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1,1)) + + def _gdown(self) -> str: + data_path = "rl-model.zip" + if not os.path.isfile(data_path): + print(f"Downloading {data_path}...") + # Place a link to your PUBLIC model data here. This is where we will download it from on the tournament server. + url = "https://drive.google.com/file/d/1JIokiBOrOClh8piclbMlpEEs6mj3H1HJ/view?usp=sharing" + gdown.download(url, output=data_path, fuzzy=True) + return data_path + + def _initialize(self) -> None: + + self.time = 0 + self.lastJumped = 0 + self.phase = "aggressive" + self.platforms = [[2.5, 6.5, 0.4], [-6.5, -2.5, 2.4]] + + self.state_mapping = { + 'WalkingState': 0, + 'StandingState': 1, + 'TurnaroundState': 2, + 'AirTurnaroundState': 3, + 'SprintingState': 4, + 'StunState': 5, + 'InAirState': 6, + 'DodgeState': 7, + 'AttackState': 8, + 'DashState': 9, + 'BackDashState': 10, + 'KOState': 11, + 'TauntState': 12, + } + + self.damage_last_time = 0 + self.state_reverse = {v: k for k, v in self.state_mapping.items()} + self.can_cross = False + self.last_dodged = 0 + self.used_attack = False + self.last_platform = None + + self.weapons_data = { + 0: [ + {"keys": ["j"], "cover": [0, 5], "type": "ground", "range": "close"}, + {"keys": ["k"], "cover": [0, 2, 3, 5], "type": "ground", "range": "close"}, + {"keys": ["d", "j"], "cover": [0, 3, 5], "type": "ground", "range": "lunge"}, + {"keys": ["s", "j"], "cover": [0, 5], "type": "ground", "range": "far"}, + {"keys": ["d", "k"], "cover": [0, 3, 5], "type": "ground", "range": "lunge"}, + {"keys": ["s", "k"], "cover": [0, 4, 5], "type": "ground", "range": "close"}, + + {"keys": ["j"], "cover": [0,5], "type": "aerial", "range": "close"}, + {"keys": ["d", "j"], "cover": [0, 5], "type": "aerial", "range": "far"}, + {"keys": ["k"], "cover": [0, 2, 3, 5], "type": "aerial", "range": "close"}, + {"keys": ["s", "k"], "cover": [5], "type": "aerial", "range": "lunge"}, + {"keys": ["s", "j"], "cover": [8], "type": "aerial", "range": "lunge"}, + ], # Fist + + # HAMMER + 2: [ + {"keys": ["j"], "cover": [5], "type": "ground", "range": "close"}, # short swing, quick poke + {"keys": ["k"], "cover": [2, 3, 5], "type": "ground", "range": "close"}, # big lunge, heavy hit + {"keys": ["d", "j"], "cover": [5], "type": "ground", "range": "lunge"}, # step-in swipe (right) + # {"keys": ["d", "k"], "cover": [2, 3, 5], "type": "ground", "range": "far"}, # heavy lunge (right) + {"keys": ["s", "j"], "cover": [5], "type": "ground", "range": "far"}, # low sweep + {"keys": ["j"], "cover": [1, 2, 3], "type": "aerial", "range": "far"}, # quick mid-air swing + {"keys": ["s", "j"], "cover": [7], "type": "aerial", "range": "far"}, + {"keys": ["d","j"], "cover": [5], "type": "aerial", "range": "far"}, + {"keys": ["k"], "cover": [2], "type": "aerial", "range": "close"}, + {"keys": ["s", "k"], "cover": [7], "type": "aerial", "range": "lunge"}, + + # big aerial smash + ], + # SPEAR + 1: [ + {"keys": ["j"], "cover": [5], "type": "ground", "range": "lunge"}, # short swing, quick poke + {"keys": ["k"], "cover": [1, 2, 3], "type": "ground", "range": "lunge"}, # big lunge, heavy hit + {"keys": ["d", "j"], "cover": [5, 3], "type": "ground", "range": "close"}, + {"keys": ["d", "j"], "cover": [5, 3], "type": "ground", "range": "far"}, # step-in swipe (right) + {"keys": ["s", "j"], "cover": [8], "type": "ground", "range": "lunge"}, # low sweep + {"keys": ["s", "k"], "cover": [3, 5], "type": "ground", "range": "close", "super_safe": True}, # downward spike + {"keys": ["j"], "cover": [1, 2, 3, 4, 5, 6, 7, 8], "type": "aerial", "range": "close"}, # quick mid-air swing + {"keys": ["d", "j"], "cover": [5], "type": "aerial", "range": "far"}, # quick mid-air swing + {"keys": ["s", "j"], "cover": [7], "type": "aerial", "range": "far"}, # quick mid-air swing + {"keys": ["k"], "cover": [1, 2, 3], "type": "aerial", "range": "lunge"}, # big aerial smash + {"keys": ["s" ,"k"], "cover": [6, 7, 8], "type": "aerial", "range": "lunge"}, + ] + } + + self.weapons_range = { + 0: 0, + 1: 0, # 0.6 + 2: 0 # 0.3 + } + + def predict(self, obs): + self.time += 1 + pos = self.obs_helper.get_section(obs, 'player_pos') + player_vel = self.obs_helper.get_section(obs, 'player_vel') + opp_pos = self.obs_helper.get_section(obs, 'opponent_pos') + opp_vel = self.obs_helper.get_section(obs, 'opponent_vel') + + + + opp_pos[0] += opp_vel[0] * 1/30 + opp_pos[1] += opp_vel[1]* 1/30 + + # if self.time % 10 == 0: + # print(opp_vel / 60) + + + opp_KO = self.obs_helper.get_section(obs, 'opponent_state') in [5, 11] + action = self.act_helper.zeros() + readible_obs = { + "pos": obs[0:2], + "vel": obs[2:4], + "facing": ( + "right" + if self.obs_helper.get_section(obs, "player_facing")[0] > 0.5 + else "left" + ), + "grounded": self.obs_helper.get_section(obs, "player_grounded")[0] > 0.5, + "aerial": self.obs_helper.get_section(obs, "player_aerial")[0] > 0.5, + "damage": self.obs_helper.get_section(obs, "player_damage")[0], + "jumps_left": int(self.obs_helper.get_section(obs, "player_jumps_left")[0]), + "opp_jumps_left": int(self.obs_helper.get_section(obs, "opponent_jumps_left")[0]), + "stun_frames": self.obs_helper.get_section(obs, 'player_stun_frames')[0], + "state": self.state_reverse.get( + int(self.obs_helper.get_section(obs, 'player_state')[0]), + "UnknownState" + ), + "opp_state": self.state_reverse.get( + int(self.obs_helper.get_section(obs, 'opponent_state')[0]), + "UnknownState" + ), + "weapon_type": int(self.obs_helper.get_section(obs, "player_weapon_type")[0]), + "opp_weapon_type": int(self.obs_helper.get_section(obs, "opponent_weapon_type")[0]), + } + moving_platform = self.obs_helper.get_section(obs, "player_moving_platform_pos") + moving_platform = [float(x) for x in moving_platform] + all_platforms = self.platforms + # all_platforms = self.platforms + [ + # [moving_platform[0] - 0.2, moving_platform[0] + 0.2, moving_platform[1]] + # ] + keys = [] + targetX = 0 + targetY = 0 + MARGIN = 0.3 + + spawner_list = [ + self.obs_helper.get_section(obs, "player_spawner_1"), + self.obs_helper.get_section(obs, "player_spawner_2"), + self.obs_helper.get_section(obs, "player_spawner_3"), + self.obs_helper.get_section(obs, "player_spawner_4"), + ] + # return action + # if self.time % 10 == 0: + # print(readible_obs["weapon_type"]) + # return action + # Find nearest active spawner + nearest_spawner = None + nearest_distance = float('inf') + + for spawner in spawner_list: + x, y, spawner_type = spawner + + if spawner_type == 0: # inactive + continue + + dx = x - pos[0] + dy = y - pos[1] + dist = math.sqrt(dx**2 + dy**2) + + if dist < nearest_distance: + nearest_distance = dist + nearest_spawner = spawner + + if nearest_spawner is not None and readible_obs["weapon_type"] == 0: + targetX = nearest_spawner[0] + targetY = nearest_spawner[1] + if readible_obs["jumps_left"] >= 2 or (readible_obs["jumps_left"] == 0 and readible_obs["grounded"]): + self.can_cross = True + self.phase = "weapon_grab" + + if nearest_distance < 0.5: + self.phase = "aggressive" + keys.append("h") + self.can_cross = False + + if nearest_spawner is None and readible_obs["weapon_type"] == 0 and readible_obs["opp_weapon_type"] != 0: + self.phase = "flee" + self.can_cross = True + + safe = False + curr_platform = None + for platform in all_platforms: + if pos[0] > platform[0] and pos[0] < platform[1] and pos[1] < platform[2] + 0.2: + safe = True + curr_platform = platform + self.last_platform = curr_platform + + super_safe = False + for platform in all_platforms: + if pos[0] > platform[0] + 1.5 and pos[0] < platform[1] - 1.5 and pos[1] < platform[2] + 0.2: + super_safe = True + + + opp_platform = None + for platform in all_platforms: + if opp_pos[0] > platform[0] and opp_pos[0] < platform[1] and opp_pos[1] < platform[2] + 0.2: + opp_platform = platform + + dx = opp_pos[0] - pos[0] + dy = opp_pos[1] - pos[1] + distance_to_opp = math.sqrt(dx**2 + dy**2) + + # return action + # GET THE SECTOR + leeway = 0.3 + sector = None + + # vertical thresholds + if abs(dx) <= leeway and dy < -leeway: + sector = 2 # above + elif abs(dx) <= leeway and dy > leeway: + sector = 7 # below + elif abs(dy) <= leeway and dx < -leeway: + sector = 4 # left + elif abs(dy) <= leeway and dx > leeway: + sector = 5 # right + elif dx < -leeway and dy < -leeway: + sector = 1 # top-left + elif dx > leeway and dy < -leeway: + sector = 3 # top-right + elif dx < -leeway and dy > leeway: + sector = 6 # bottom-left + elif dx > leeway and dy > leeway: + sector = 8 # bottom-right + else: + sector = 0 # overlapping / same position + # if self.time % 10 == 0: + # print(sector) + + # return action + + + distances = [] + for platform in all_platforms: + x_min, x_max, y_level = platform + if pos[0] < x_min: + dist = x_min - pos[0] + elif pos[0] > x_max: + dist = pos[0] - x_max + else: + dist = 0 # directly above platform + + if y_level < pos[1]: + dist += (pos[1] - y_level) * 1.8 + + distances.append(dist) + + # Find the nearest platform + nearest_index = distances.index(min(distances)) + nearest_platform = all_platforms[nearest_index] + x_min_nearest, x_max_nearest, y_level_nearest = nearest_platform + do_not_move = False + if dx > 0: + opp_direction_x = "right" + else: + opp_direction_x = "left" + + if (curr_platform and opp_platform and curr_platform == opp_platform and self.phase == "passive") or self.damage_last_time < readible_obs["damage"]: #self.damage_last_time < readible_obs["damage"] + self.phase = "aggressive" + self.damage_last_time = readible_obs["damage"] + self.can_cross = False + + if not safe and not self.can_cross: + if pos[0] < x_min_nearest: + targetX = x_min_nearest + elif pos[0] > x_max_nearest: + targetX = x_max_nearest + + targetY = y_level_nearest + + if self.phase == "passive" and safe: + if opp_direction_x == "left": + targetX = max(curr_platform[0], opp_pos[0]) + else: + targetX = min(curr_platform[1], opp_pos[0]) + + targetY = curr_platform[2] + elif self.phase == "flee": + flee_platform = all_platforms[0] + if opp_platform is None: + if curr_platform is not None: + flee_platform = curr_platform + elif opp_platform == all_platforms[0]: + flee_platform = all_platforms[1] + elif opp_platform == all_platforms[0]: + flee_platform = all_platforms[0] + targetX = (flee_platform[1] + flee_platform[0]) / 2 + targetY = flee_platform[2] + elif self.phase == "aggressive" and safe and not self.used_attack: # opp_direction_x == readible_obs["facing"] + curr_weapon = readible_obs["weapon_type"] + facing = readible_obs["facing"] + grounded = readible_obs["grounded"] + attack_state = "ground" if grounded else "aerial" + + curr_weapon_range = self.weapons_range[curr_weapon] + if distance_to_opp < 1 + curr_weapon_range: + attack_type = "close" + elif distance_to_opp < 1.5 + curr_weapon_range: + attack_type = "far" + elif distance_to_opp < 2.2 + curr_weapon_range: + attack_type = "lunge" + else: + attack_type = "out_of_range" + + if attack_type != "out_of_range": + + # Flip mapping for facing left + sector_flip = {1: 3, 3: 1, 4: 5, 5: 4, 6: 8, 8: 6, 2: 2, 7: 7, 0: 0} + eff_sector = sector_flip[sector] if facing == "left" else sector + + # Collect all valid moves + valid_moves = [] + for move in self.weapons_data[curr_weapon]: + if move["type"] != attack_state: + continue + if move["range"] != attack_type: + continue + if eff_sector not in move["cover"]: + continue + if move.get("super_safe") and not super_safe: + continue + valid_moves.append(move) + + # Pick one at random if any valid moves exist + if valid_moves: + move = random.choice(valid_moves) + keys_to_press = move["keys"].copy() + + # Flip horizontal keys if facing left + if facing == "left": + keys_to_press = ["a" if k=="d" else "d" if k=="a" else k for k in keys_to_press] + + # Press keys + for k in keys_to_press: + keys.append(k) + do_not_move = True + self.used_attack = True + elif self.used_attack: + self.used_attack = False + + # TRAVEL HERE + if self.phase == "aggressive" and opp_platform is not None and self.last_platform is not None and self.last_platform != opp_platform \ + and not (readible_obs["aerial"] and readible_obs["jumps_left"] == 0 and pos[1] >= opp_platform[2]): + targetX = (opp_platform[1] + opp_platform[0]) / 2 + targetY = opp_platform[2] + + elif not safe: + MARGIN = 0 + self.can_cross = False + elif self.phase == "aggressive" and not opp_KO: + MARGIN = 0 + + + targetX = opp_pos[0] + targetY = opp_pos[1] + self.can_cross = False + + if not do_not_move: + + if pos[1] > moving_platform[1] and pos[0] > moving_platform[0] - 1.5 and pos[0] < moving_platform[0] + 1.5: + keys.append('a') + elif pos[0] < targetX - MARGIN: + keys.append('d') + elif pos[0] >= targetX + MARGIN: + keys.append('a') + + # print(player_vel) + if pos[1] > targetY + MARGIN and self.time - self.lastJumped > 10: + if not safe: + self.lastJumped = self.time + keys.append('space') + if readible_obs["jumps_left"] == 0 and readible_obs["aerial"]: + keys.append('w') + keys.append('k') + elif random.randint(0, 100) > 95: + self.lastJumped = self.time + keys.append('space') + + if distance_to_opp < 1 and readible_obs["opp_state"] == "AttackState" and self.time - self.last_dodged > 50: + self.last_dodged = self.time + keys = [] + keys.append("l") + + + action = self.act_helper.press_keys(keys, action) + return action + + + def save(self, file_path: str) -> None: + self.model.save(file_path) + + # If modifying the number of models (or training in general), modify this + def learn(self, env, total_timesteps, log_interval: int = 4): + self.model.set_env(env) + self.model.learn(total_timesteps=total_timesteps, log_interval=log_interval) + + \ No newline at end of file diff --git a/user/plot_learning_curve.py b/user/plot_learning_curve.py new file mode 100644 index 0000000..55653a0 --- /dev/null +++ b/user/plot_learning_curve.py @@ -0,0 +1,178 @@ +""" +Generate learning curve visualization from training data. +Run this after training to analyze progress and identify areas for improvement. +""" + +import os +import sys +import numpy as np +import matplotlib.pyplot as plt +from stable_baselines3.common.results_plotter import load_results, ts2xy + +def plot_learning_curve(experiment_path): + """ + Plot learning curve with improved visualization and analysis. + + Args: + experiment_path: Path to checkpoint folder (e.g., 'checkpoints/experiment_aggressive_v3') + """ + if not os.path.exists(experiment_path): + print(f"Error: Path {experiment_path} does not exist!") + return + + monitor_file = os.path.join(experiment_path, 'monitor.csv') + if not os.path.exists(monitor_file): + print(f"Error: monitor.csv not found in {experiment_path}!") + return + + # Load data + x, y = ts2xy(load_results(experiment_path), "timesteps") + + if len(x) == 0 or len(y) == 0: + print("Error: No data found in monitor.csv!") + return + + # Smoothing window + window_size = 100 + weights = np.repeat(1.0, window_size) / window_size + y_smoothed = np.convolve(y, weights, "valid") + x_smoothed = x[len(x) - len(y_smoothed):] + + # Create figure with subplots + fig = plt.figure(figsize=(16, 10)) + + # Main learning curve + ax1 = plt.subplot(2, 2, 1) + ax1.plot(x_smoothed, y_smoothed, linewidth=2, color='blue', label='Smoothed Reward') + ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5, linewidth=1) + ax1.set_xlabel("Timesteps", fontsize=11) + ax1.set_ylabel("Episode Reward", fontsize=11) + ax1.set_title("Learning Curve (Smoothed, Window=100)", fontsize=12, fontweight='bold') + ax1.grid(True, alpha=0.3) + ax1.legend() + ax1.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1e6:.1f}M')) + + # Recent performance (last 1M steps) + ax2 = plt.subplot(2, 2, 2) + if len(x_smoothed) > 0: + last_1m = x_smoothed[-1] - 1_000_000 + mask = x_smoothed >= max(0, last_1m) + if np.any(mask): + ax2.plot(x_smoothed[mask], y_smoothed[mask], linewidth=2, color='green') + ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + ax2.set_xlabel("Timesteps", fontsize=11) + ax2.set_ylabel("Episode Reward", fontsize=11) + ax2.set_title("Recent Performance (Last 1M Steps)", fontsize=12, fontweight='bold') + ax2.grid(True, alpha=0.3) + ax2.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1e6:.1f}M')) + + # Statistics + ax3 = plt.subplot(2, 2, 3) + ax3.axis('off') + + # Calculate statistics + total_steps = x[-1] if len(x) > 0 else 0 + final_avg_reward = np.mean(y_smoothed[-1000:]) if len(y_smoothed) >= 1000 else np.mean(y_smoothed) + max_reward = np.max(y_smoothed) + min_reward = np.min(y_smoothed) + recent_trend = np.mean(y_smoothed[-1000:]) - np.mean(y_smoothed[-5000:-1000]) if len(y_smoothed) >= 5000 else 0 + + stats_text = f""" +TRAINING STATISTICS +═══════════════════════════════════════ +Total Timesteps: {total_steps/1e6:.2f}M +Final Average Reward: {final_avg_reward:.2f} +Maximum Reward: {max_reward:.2f} +Minimum Reward: {min_reward:.2f} +Recent Trend: {'+' if recent_trend > 0 else ''}{recent_trend:.2f} + +INTERPRETATION: +═══════════════════════════════════════ +• Final reward > 0: Agent winning more than losing +• Final reward < 0: Agent needs improvement +• Upward trend: Learning successfully +• Downward trend: May need reward tuning + +CURRENT STATUS: +═══════════════════════════════════════ +""" + if final_avg_reward > 0: + stats_text += "✓ Agent is performing well (positive rewards)" + elif final_avg_reward > -100: + stats_text += "⚠ Agent is improving but still negative rewards" + else: + stats_text += "✗ Agent struggling (highly negative rewards)" + + if recent_trend > 5: + stats_text += "\n✓ Strong upward trend - keep training!" + elif recent_trend > 0: + stats_text += "\n✓ Slow improvement - patience needed" + elif recent_trend > -5: + stats_text += "\n⚠ Plateauing - may need hyperparameter tuning" + else: + stats_text += "\n✗ Declining performance - check reward function" + + ax3.text(0.1, 0.5, stats_text, fontsize=10, fontfamily='monospace', + verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + # Reward distribution histogram + ax4 = plt.subplot(2, 2, 4) + ax4.hist(y_smoothed, bins=50, alpha=0.7, color='steelblue', edgecolor='black') + ax4.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero Reward') + ax4.set_xlabel("Episode Reward", fontsize=11) + ax4.set_ylabel("Frequency", fontsize=11) + ax4.set_title("Reward Distribution", fontsize=12, fontweight='bold') + ax4.legend() + ax4.grid(True, alpha=0.3) + + plt.suptitle(f"Training Analysis: {os.path.basename(experiment_path)}", + fontsize=14, fontweight='bold', y=0.98) + plt.tight_layout(rect=[0, 0, 1, 0.97]) + + # Save + output_path = os.path.join(experiment_path, "Learning Curve.png") + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"✓ Learning curve saved to: {output_path}") + + # Also save full analysis + analysis_path = os.path.join(experiment_path, "Training Analysis.png") + plt.savefig(analysis_path, dpi=150, bbox_inches='tight') + print(f"✓ Full analysis saved to: {analysis_path}") + + plt.close() + + # Print summary + print(f"\n{'='*60}") + print(f"SUMMARY for {os.path.basename(experiment_path)}") + print(f"{'='*60}") + print(f"Total Steps: {total_steps/1e6:.2f}M") + print(f"Final Avg Reward: {final_avg_reward:.2f}") + print(f"Recent Trend: {recent_trend:+.2f} per 1000 episodes") + print(f"{'='*60}\n") + +if __name__ == "__main__": + # Default to latest experiment + if len(sys.argv) > 1: + experiment_path = sys.argv[1] + else: + # Find latest experiment + checkpoint_dir = "checkpoints" + if os.path.exists(checkpoint_dir): + experiments = [d for d in os.listdir(checkpoint_dir) + if os.path.isdir(os.path.join(checkpoint_dir, d)) + and d.startswith('experiment')] + if experiments: + # Get most recently modified + experiments.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + experiment_path = os.path.join(checkpoint_dir, experiments[0]) + print(f"Using latest experiment: {experiment_path}") + else: + print("No experiments found! Please specify path:") + print(f" python user/plot_learning_curve.py checkpoints/experiment_name") + sys.exit(1) + else: + print("checkpoints directory not found!") + sys.exit(1) + + plot_learning_curve(experiment_path) + diff --git a/user/pvp_match.py b/user/pvp_match.py index cb52bba..29bb0cd 100644 --- a/user/pvp_match.py +++ b/user/pvp_match.py @@ -2,6 +2,7 @@ from environment.agent import run_real_time_match from user.train_agent import UserInputAgent, BasedAgent, ConstantAgent, ClockworkAgent, SB3Agent, RecurrentPPOAgent #add anymore custom Agents (from train_agent.py) here as needed from user.my_agent import SubmittedAgent +from user.opp_agent import SubmittedAgent as opp import pygame pygame.init() @@ -10,12 +11,13 @@ #Input your file path here in SubmittedAgent if you are loading a model: opponent = SubmittedAgent(file_path=None) -match_time = 99999 +match_time = 999999 # Run a single real-time match run_real_time_match( - agent_1=my_agent, - agent_2=opponent, - max_timesteps=30 * 999990000, # Match time in frames (adjust as needed) + agent_1=opponent, # Your AI + agent_2=opp(file_path="checkpoints/v2.zip"), # You + #agent_2=my_agent, + max_timesteps=30 * 9999, # Match time in frames (adjust as needed) resolution=CameraResolution.LOW, -) \ No newline at end of file +) diff --git a/user/train_agent.py b/user/train_agent.py index 7356155..b7f285f 100644 --- a/user/train_agent.py +++ b/user/train_agent.py @@ -100,13 +100,16 @@ def _initialize(self) -> None: 'share_features_extractor': True, } + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") self.model = RecurrentPPO("MlpLstmPolicy", self.env, verbose=0, n_steps=30*90*20, batch_size=16, ent_coef=0.05, - policy_kwargs=policy_kwargs) + policy_kwargs=policy_kwargs, + device=device) del self.env else: self.model = RecurrentPPO.load(self.file_path) @@ -403,7 +406,9 @@ def damage_interaction_reward( else: raise ValueError(f"Invalid mode: {mode}") - return reward / 140 + # Clip per-step reward to [-0.2, +0.2] + reward = reward / 140 + return clip_reward(reward) # In[ ]: @@ -513,16 +518,27 @@ def on_win_reward(env: WarehouseBrawl, agent: str) -> float: return -1.0 def on_knockout_reward(env: WarehouseBrawl, agent: str) -> float: + """ + Two separate events: + - You KO them: +50 + - You get KO'd: -50 + """ + # This signal is emitted for the agent that got KO'd + # So 'player' means player got KO'd, opponent got KO'd means win if agent == 'player': - return -1.0 + return -1.0 # Player got KO'd (will be scaled to -50) else: - return 1.0 + return 1.0 # Opponent got KO'd, player wins (will be scaled to +50) def on_equip_reward(env: WarehouseBrawl, agent: str) -> float: + """ + Simple equip reward to prevent weapon camping. + No special bonuses for specific weapons. + """ if agent == "player": - if env.objects["player"].weapon == "Hammer": - return 2.0 - elif env.objects["player"].weapon == "Spear": + # Just acknowledge weapon pickup, no special bonuses + player = env.objects["player"] + if player.weapon in ["Hammer", "Spear"]: return 1.0 return 0.0 @@ -533,31 +549,327 @@ def on_drop_reward(env: WarehouseBrawl, agent: str) -> float: return 0.0 def on_combo_reward(env: WarehouseBrawl, agent: str) -> float: + """Reward per EXTRA hit after the first - prevents dwarfing KO credit""" if agent == 'player': - return -1.0 + damage_dealt = env.objects["opponent"].damage_taken_this_frame + # Pay per extra hit (first hit gets no combo bonus, subsequent hits do) + # This encourages extending but doesn't overshadow KOs + return 0.05 # Per extra hit, scaled by weight else: - return 1.0 + return -0.05 + +def edge_to_ko_bonus(env: WarehouseBrawl, agent: str) -> float: + """Bonus for converting edge hits into KOs within 2 seconds""" + if agent != 'player': + return 0.0 + + player = env.objects["player"] + opponent = env.objects["opponent"] + + # Check if opponent was near edge (within 3 units) when KO'd + edge_x = env.stage_width_tiles // 2 + opponent_dist_to_left = abs(opponent.body.position.x + edge_x) + opponent_dist_to_right = abs(opponent.body.position.x - edge_x) + min_edge_dist = min(opponent_dist_to_left, opponent_dist_to_right) + + # Bonus if opponent was near edge when KO'd + if min_edge_dist < 3.0: + return 8.0 + + return 0.0 + +def whiff_punishment_reward(env: WarehouseBrawl) -> float: + """ + STRONG penalties for whiffing attacks, especially when far from opponent. + Stops spam attacking the air. + """ + player = env.objects["player"] + opponent = env.objects["opponent"] + + # Calculate distance + dx = player.body.position.x - opponent.body.position.x + dy = player.body.position.y - opponent.body.position.y + distance = (dx**2 + dy**2)**0.5 + + threat_range = 3.0 + far_range = 5.0 + + # Check if player is attacking + if hasattr(player.state, 'move_type') and player.state.move_type != MoveType.NONE: + # STRONGER penalty if attacking when opponent is far away + if distance > far_range: + return -0.15 # Heavy penalty for attacking air when far + elif distance > threat_range: + return -0.08 # Medium penalty for attacking outside threat range + elif distance < threat_range: + # Classify as heavy or light whiff in range + is_heavy = player.state.move_type in [ + MoveType.NSIG, MoveType.DSIG, MoveType.SSIG, MoveType.DAIR, MoveType.SAIR + ] + + if is_heavy: + return -0.10 # INCREASED: Heavy whiff in range + else: + return -0.05 # INCREASED: Light whiff in range + + return 0.0 + +def time_pressure_reward(env: WarehouseBrawl) -> float: + """ + Tiny constant per-step penalty to encourage decisive action. + """ + return -0.0002 # Tiny constant per step + +def retreat_penalty(env: WarehouseBrawl) -> float: + """ + Penalizes retreating after grace window, but NOT when kiting for better engage. + """ + player = env.objects["player"] + opponent = env.objects["opponent"] + + dx = player.body.position.x - opponent.body.position.x + vel_x = player.body.velocity.x + + # Check if moving away from opponent + is_retreating = (dx > 0 and vel_x > 0.1) or (dx < 0 and vel_x < -0.1) + + if is_retreating: + # Check if kiting toward a better position (moving toward opponent diagonally) + # For now, simple check: if opponent is in bad position, allow kiting + opp_in_bad_position = opponent.body.position.y > 4.0 or opponent.body.position.y < -2.0 + + if not opp_in_bad_position: + # This is a retreat, apply penalty + return -0.01 * env.dt # Grace window handled by small penalty + + return 0.0 + +def advantage_state_reward(env: WarehouseBrawl) -> float: + """ + Rewards opponent in hitstun + tiny per-frame hitstun reward. + Encourages sustained pressure. + """ + player = env.objects["player"] + opponent = env.objects["opponent"] + + # Base reward for having opponent stunned + # Check if opponent is in StunState + if isinstance(opponent.state, StunState): + base_reward = 0.05 * env.dt # Advantage state reward + hitstun_bonus = 0.02 * env.dt # Tiny per-frame hitstun reward + return base_reward + hitstun_bonus + + return 0.0 + +def clip_reward(reward: float, min_val: float = -0.2, max_val: float = 0.2) -> float: + """ + Clip per-step rewards to [-0.2, +0.2]. + Terminal rewards are NOT clipped. + """ + return max(min_val, min(max_val, reward)) + +def proximity_to_opponent_reward(env: WarehouseBrawl) -> float: + """ + Rewards getting close to opponent - encourages chasing and engagement. + Reward increases as distance decreases. + """ + player: Player = env.objects["player"] + opponent: Player = env.objects["opponent"] + + # Calculate distance + dx = player.body.position.x - opponent.body.position.x + dy = player.body.position.y - opponent.body.position.y + distance = (dx**2 + dy**2)**0.5 + + max_distance = 15.0 # Maximum arena distance + + # Don't reward if player is stunned + if isinstance(player.state, StunState): + return 0.0 + + # Reward for being close - closer = more reward + # Inverse relationship: closer gets more reward + if distance < max_distance: + reward = (max_distance - distance) / max_distance + return reward * env.dt * 0.05 # INCREASED per-frame reward + + return 0.0 + +def edge_avoidance_reward(env: WarehouseBrawl, danger_zone: float = 3.0) -> float: + """ + Penalizes agent for being near map edges, BUT only when opponent is NOT off-stage. + Allows edge-guarding when opponent is disadvantaged. + """ + player: Player = env.objects["player"] + opponent: Player = env.objects["opponent"] + + # Check if opponent is off-stage - if so, edge-guarding is valid + edge_x = env.stage_width_tiles // 2 + edge_y = env.stage_height_tiles // 2 + + opponent_dist_to_left = abs(opponent.body.position.x + edge_x) + opponent_dist_to_right = abs(opponent.body.position.x - edge_x) + opponent_dist_to_bottom = abs(opponent.body.position.y + edge_y) + opponent_dist_to_top = abs(opponent.body.position.y - edge_y) + opponent_min_dist = min(opponent_dist_to_left, opponent_dist_to_right, + opponent_dist_to_bottom, opponent_dist_to_top) + + # If opponent is off-stage, allow edge-guarding + if opponent_min_dist < 2.0 or opponent.body.position.y > 4.5: + return 0.0 # No penalty when opponent is off-stage + + # Get arena boundaries + edge_x = env.stage_width_tiles // 2 + edge_y = env.stage_height_tiles // 2 + + # Distance to each edge + dist_to_left = abs(player.body.position.x + edge_x) + dist_to_right = abs(player.body.position.x - edge_x) + dist_to_bottom = abs(player.body.position.y + edge_y) + dist_to_top = abs(player.body.position.y - edge_y) + + # Find minimum distance to any edge + min_dist_to_edge = min(dist_to_left, dist_to_right, dist_to_bottom, dist_to_top) + + # Penalty if too close to edges (when opponent is not off-stage) + # STRONGER penalty that scales more aggressively + if min_dist_to_edge < danger_zone: + # Quadratic penalty for being near edge + penalty_ratio = min_dist_to_edge / danger_zone + penalty = -((1.0 - penalty_ratio) ** 2) * 2.0 # Stronger near edges + else: + penalty = 0.0 + + # Clip per-step reward but allow stronger penalties + return clip_reward(penalty * env.dt * 2.0) # Scale up the penalty + +def fall_velocity_penalty(env: WarehouseBrawl, max_safe_velocity: float = 70.0) -> float: + """ + Penalizes rapid falling ONLY when off-stage AND recovery resources low. + Does NOT penalize fast-fall confirms on-stage. + """ + player: Player = env.objects["player"] + + edge_x = env.stage_width_tiles // 2 + edge_y = env.stage_height_tiles // 2 + + # Check if player is off-stage + dist_to_left = abs(player.body.position.x + edge_x) + dist_to_right = abs(player.body.position.x - edge_x) + dist_to_bottom = abs(player.body.position.y + edge_y) + player_min_dist = min(dist_to_left, dist_to_right, dist_to_bottom) + + is_offstage = player_min_dist < 2.0 + + # Penalize rapid falling both on-stage (near edges) and off-stage + if player.body.velocity.y < -max_safe_velocity: + # Check recovery resources (jumps/recoveries left) + has_recovery = False + if hasattr(player.state, 'jumps_left'): + has_recovery = player.state.jumps_left > 0 + if hasattr(player.state, 'recoveries_left'): + has_recovery = has_recovery or player.state.recoveries_left > 0 + + # STRONGER penalty if off-stage or near edge + is_near_edge = player_min_dist < 3.0 + + if is_offstage or (is_near_edge and not has_recovery): + velocity_penalty = abs(player.body.velocity.y) / max_safe_velocity - 1.0 + # STRONGER penalty + return clip_reward(-velocity_penalty * env.dt * 1.0) + + return 0.0 + +def survival_bonus(env: WarehouseBrawl) -> float: + """ + Small bonus for staying alive and on-stage. + Encourages not jumping off. + """ + player: Player = env.objects["player"] + + edge_x = env.stage_width_tiles // 2 + edge_y = env.stage_height_tiles // 2 + + # Check if on-stage + dist_to_left = abs(player.body.position.x + edge_x) + dist_to_right = abs(player.body.position.x - edge_x) + dist_to_bottom = abs(player.body.position.y + edge_y) + player_min_dist = min(dist_to_left, dist_to_right, dist_to_bottom) + + is_onstage = player_min_dist >= 2.0 + + if is_onstage: + return 0.01 * env.dt # Small survival bonus + + return 0.0 + +def weapon_stability_reward(env: WarehouseBrawl) -> float: + """ + Small constant reward for having any weapon (discourages constant switching). + """ + player: Player = env.objects["player"] + + # Reward for having a weapon, slightly more for better weapons + if player.weapon == "Hammer": + return 0.02 * env.dt + elif player.weapon == "Spear": + return 0.01 * env.dt + else: + return 0.0 ''' Add your dictionary of RewardFunctions here using RewTerms ''' def gen_reward_manager(): reward_functions = { - #'target_height_reward': RewTerm(func=base_height_l2, weight=0.0, params={'target_height': -4, 'obj_name': 'player'}), - 'danger_zone_reward': RewTerm(func=danger_zone_reward, weight=0.5), - 'damage_interaction_reward': RewTerm(func=damage_interaction_reward, weight=1.0), - #'head_to_middle_reward': RewTerm(func=head_to_middle_reward, weight=0.01), - #'head_to_opponent': RewTerm(func=head_to_opponent, weight=0.05), - 'penalize_attack_reward': RewTerm(func=in_state_reward, weight=-0.04, params={'desired_state': AttackState}), - 'holding_more_than_3_keys': RewTerm(func=holding_more_than_3_keys, weight=-0.01), - #'taunt_reward': RewTerm(func=in_state_reward, weight=0.2, params={'desired_state': TauntState}), + # BALANCED REWARD SYSTEM - Terminal rewards dominate + # Symmetric damage - agent pays for bad trades + 'damage_interaction_reward': RewTerm(func=damage_interaction_reward, weight=10.0, params={'mode': RewardMode.SYMMETRIC}), + + # NEW: Advantage state reward - encourages pressure maintenance (reduced to avoid accumulation) + 'advantage_state_reward': RewTerm(func=advantage_state_reward, weight=1.5), + + # Whiff punishment - STRONGER to stop spam attacking air + 'whiff_punishment_reward': RewTerm(func=whiff_punishment_reward, weight=3.0), + + # Time pressure - prevent stalling + 'time_pressure_reward': RewTerm(func=time_pressure_reward, weight=0.5), + + # INCREASED: Retreat penalty - stop running away, engage! + 'retreat_penalty': RewTerm(func=retreat_penalty, weight=1.5), + + # Reduced weapon stability - let agent decide when to switch + 'weapon_stability_reward': RewTerm(func=weapon_stability_reward, weight=0.5), + + # Contextual edge avoidance - allows edge-guarding (STRONG penalty to prevent jumping off) + 'edge_avoidance_reward': RewTerm(func=edge_avoidance_reward, weight=12.0, params={'danger_zone': 3.0}), + 'fall_velocity_penalty': RewTerm(func=fall_velocity_penalty, weight=5.0, params={'max_safe_velocity': 60.0}), + + # Survival bonus - encourage staying alive + 'survival_bonus': RewTerm(func=survival_bonus, weight=3.0), + + # INCREASED proximity/chase rewards - encourage running at opponent + 'proximity_to_opponent_reward': RewTerm(func=proximity_to_opponent_reward, weight=3.0), + 'head_to_opponent': RewTerm(func=head_to_opponent, weight=2.0), + + # Keep these disabled/zero + 'danger_zone_reward': RewTerm(func=danger_zone_reward, weight=0.0), + 'penalize_attack_reward': RewTerm(func=in_state_reward, weight=0.0), + 'holding_more_than_3_keys': RewTerm(func=holding_more_than_3_keys, weight=0.0), } signal_subscriptions = { - 'on_win_reward': ('win_signal', RewTerm(func=on_win_reward, weight=50)), - 'on_knockout_reward': ('knockout_signal', RewTerm(func=on_knockout_reward, weight=8)), - 'on_combo_reward': ('hit_during_stun', RewTerm(func=on_combo_reward, weight=5)), - 'on_equip_reward': ('weapon_equip_signal', RewTerm(func=on_equip_reward, weight=10)), - 'on_drop_reward': ('weapon_drop_signal', RewTerm(func=on_drop_reward, weight=15)) + # TERMINAL REWARDS - These dominate to ensure winning is the main goal + 'on_win_reward': ('win_signal', RewTerm(func=on_win_reward, weight=100)), + 'on_knockout_reward': ('knockout_signal', RewTerm(func=on_knockout_reward, weight=100)), # You KO them: +100, You get KO'd: -100 (DOUBLED) + # Combo per extra hit - prevents dwarfing KO (weight 6-10) + 'on_combo_reward': ('hit_during_stun', RewTerm(func=on_combo_reward, weight=8)), + + # Edge-to-KO conversion bonus + 'edge_to_ko_bonus': ('knockout_signal', RewTerm(func=edge_to_ko_bonus, weight=1.0)), + + # Weapon rewards - simple, no special bonuses + 'on_equip_reward': ('weapon_equip_signal', RewTerm(func=on_equip_reward, weight=0.5)), + 'on_drop_reward': ('weapon_drop_signal', RewTerm(func=on_drop_reward, weight=-2)) } return RewardManager(reward_functions, signal_subscriptions) @@ -568,8 +880,11 @@ def gen_reward_manager(): The main function runs training. You can change configurations such as the Agent type or opponent specifications here. ''' if __name__ == '__main__': - # Create agent + # Start FRESH with improved rewards (RECOMMENDED) my_agent = CustomAgent(sb3_class=PPO, extractor=MLPExtractor) + + # OR: Continue from checkpoint (only if you want to try adapting old behavior): + # my_agent = CustomAgent(sb3_class=PPO, file_path='checkpoints/experiment_fixed_v2/rl_model_XXXXXX_steps.zip', extractor=MLPExtractor) # Start here if you want to train from scratch. e.g: #my_agent = RecurrentPPOAgent() @@ -588,18 +903,18 @@ def gen_reward_manager(): # Set save settings here: save_handler = SaveHandler( agent=my_agent, # Agent to save - save_freq=100_000, # Save frequency + save_freq=50_000, # Save frequency - more frequent to catch good models max_saved=40, # Maximum number of saved models save_path='checkpoints', # Save path - run_name='experiment_9', - mode=SaveHandlerMode.FORCE # Save mode, FORCE or RESUME + run_name='experiment_aggressive_v3', # Fresh training with aggressive chase + no whiffs + mode=SaveHandlerMode.FORCE # Start completely fresh ) # Set opponent settings here: opponent_specification = { 'self_play': (8, selfplay_handler), - 'constant_agent': (0.5, partial(ConstantAgent)), - 'based_agent': (1.5, partial(BasedAgent)), + 'constant_agent': (2, partial(ConstantAgent)), # Increased from 0.5 to 2 + 'based_agent': (2, partial(BasedAgent)), # Increased from 1.5 to 2 } opponent_cfg = OpponentsCfg(opponents=opponent_specification) @@ -608,6 +923,6 @@ def gen_reward_manager(): save_handler, opponent_cfg, CameraResolution.LOW, - train_timesteps=1_000_000_000, + train_timesteps=10_000_000, # Continue training (total 15M from start) train_logging=TrainLogging.PLOT ) \ No newline at end of file diff --git a/validate_my_agent.py b/validate_my_agent.py new file mode 100644 index 0000000..4ed8d7c --- /dev/null +++ b/validate_my_agent.py @@ -0,0 +1,61 @@ +""" +Simple validation script - no dependencies on Supabase or video recording +""" + +from environment.agent import ConstantAgent, run_match, CameraResolution +from user.my_agent import SubmittedAgent + +print("=" * 60) +print("VALIDATING YOUR TRAINED AGENT") +print("=" * 60) + +print("\n1. Loading your trained agent...") +my_agent = SubmittedAgent() +print(" [OK] Agent loaded successfully!") + +print("\n2. Running match against ConstantAgent (dummy opponent)...") +print(" (This should be an easy win for your agent)") + +opponent = ConstantAgent() + +# Run the match WITHOUT video to avoid FFmpeg issues +stats = run_match( + my_agent, + agent_2=opponent, + video_path=None, # No video + agent_1_name='Your Agent', + agent_2_name='Constant Agent', + resolution=CameraResolution.LOW, + max_timesteps=30 * 90, # 90 seconds + train_mode=False +) + +print(f"\n{'=' * 60}") +print(f"MATCH RESULTS") +print(f"{'=' * 60}") +print(f"Your Agent:") +print(f" - Damage taken: {stats.player1.damage_taken}%") +print(f" - Lives remaining: {stats.player1.lives_left}") +print(f" - Damage dealt: {stats.player1.damage_done}") +print(f"\nOpponent (ConstantAgent):") +print(f" - Damage taken: {stats.player2.damage_taken}%") +print(f" - Lives remaining: {stats.player2.lives_left}") +print(f" - Damage dealt: {stats.player2.damage_done}") +print(f"\nMatch time: {stats.match_time:.1f} seconds") +print(f"\nResult: ", end='') + +if stats.player1_result.value == 1: + print("*** YOU WON! Your agent is working! ***") +elif stats.player1_result.value == 0: + print("DRAW (timeout or tie)") +else: + print("*** You lost (check the agent) ***") + +print(f"{'=' * 60}") +print("\nYour agent is ready to submit to the tournament!") +print("Next steps:") +print("1. Push your code to GitHub") +print("2. Run the GitHub Actions validation pipeline") +print("3. Battle other teams!") +print(f"{'=' * 60}") +