Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions metamon/backend/team_prediction/usage_stats/stat_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from termcolor import colored
import metamon
from metamon.config import format_for_agent
from metamon.backend.team_prediction.usage_stats.format_rules import (
get_valid_pokemon,
Tier,
Expand Down Expand Up @@ -476,6 +477,7 @@ def get_usage_stats(
start_date: Optional[datetime.date] = None,
end_date: Optional[datetime.date] = None,
) -> PreloadedSmogonUsageStats:
format = format_for_agent(format)
if start_date is None or start_date < EARLIEST_USAGE_STATS_DATE:
start_date = EARLIEST_USAGE_STATS_DATE
else:
Expand Down
11 changes: 11 additions & 0 deletions metamon/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,15 @@
"gen9ou",
]

FORMAT_ALIASES = {
"gen1oulongtimer": "gen1ou",
"gen9oulongtimer": "gen9ou",
}


def format_for_agent(fmt: str) -> str:
"""Lets metamon play non-standard Showdown formats by pretending they're something else"""
return FORMAT_ALIASES.get(fmt.lower(), fmt.lower())


METAMON_CACHE_DIR = os.environ.get("METAMON_CACHE_DIR", None)
13 changes: 8 additions & 5 deletions metamon/env/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from poke_env.ps_client.server_configuration import ServerConfiguration

import metamon
from metamon.config import format_for_agent
from metamon.interface import (
UniversalState,
UniversalAction,
Expand All @@ -29,7 +30,7 @@
)
from metamon.data import DATA_PATH
from metamon.data.download import download_teams
from metamon.env.metamon_player import MetamonPlayer
from metamon.env.metamon_player import MetamonPlayer, PokeAgentPlayer


METAMON_TEAM_SETS = {
Expand Down Expand Up @@ -59,7 +60,7 @@ class TeamSet(Teambuilder):
def __init__(self, team_file_dir: str, battle_format: str):
super().__init__()
self.team_file_dir = team_file_dir
self.battle_format = battle_format.lower()
self.battle_format = format_for_agent(battle_format)
self.team_files = self._find_team_files()
self._most_recent_team_file = None

Expand Down Expand Up @@ -130,21 +131,23 @@ def get_metamon_teams(

Args:
battle_format: The battle format of the team files (e.g. "gen1ou", "gen2ubers", etc.).
Showdown variants (e.g. "gen1oulongtimer") are normalized automatically.
set_name: The name of the set of teams to download. See the README for options. If a custom name is provided,
we will search the `METAMON_CACHE_DIR` for a custom team set with that name.
set_type: The type of TeamSet to return. Defaults to TeamSet.
"""
fmt = format_for_agent(battle_format)
if set_name in METAMON_TEAM_SETS:
path = download_teams(battle_format, set_name=set_name)
path = download_teams(fmt, set_name=set_name)
elif metamon.METAMON_CACHE_DIR is not None:
path = os.path.join(metamon.METAMON_CACHE_DIR, "teams", set_name)
else:
raise ValueError(f"`METAMON_CACHE_DIR` environment variable is not set!")
if not os.path.exists(path):
raise ValueError(
f"Cannot locate valid team directory for format {battle_format} at path {path}"
f"Cannot locate valid team directory for format {fmt} at path {path}"
)
return set_type(path, battle_format)
return set_type(path, fmt)


def _check_avatar(avatar: str):
Expand Down
8 changes: 7 additions & 1 deletion metamon/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from poke_env.player import BattleOrder, Player

import metamon
from metamon.config import format_for_agent
from metamon.tokenizer import PokemonTokenizer, UNKNOWN_TOKEN
from metamon.backend.replay_parser.replay_state import (
Move as ReplayMove,
Expand Down Expand Up @@ -511,6 +512,11 @@ class UniversalState:
can_tera: bool # added v3-beta
opponent_teampreview: List[str] # added v3

@property
def agent_format(self) -> str:
"""The format as presented to the agent, with Showdown variants normalized."""
return format_for_agent(self.format)

@staticmethod
def universal_conditions(condition_rep) -> str:
if not condition_rep:
Expand Down Expand Up @@ -1155,7 +1161,7 @@ def state_to_obs(self, state: UniversalState) -> dict[str, np.ndarray]:
+ self._get_move_string_features(state.opponent_prev_move, active=False)
)
full_text_list = (
[f"<{state.format}>", force_switch]
[f"<{state.agent_format}>", force_switch]
+ player_str
+ move_str
+ switch_str
Expand Down