From d2450fecbad24b67a7c93455709607787b2c2523 Mon Sep 17 00:00:00 2001 From: an-swe Date: Tue, 4 Nov 2025 20:08:34 -0800 Subject: [PATCH 01/34] add type stubs for static type checking support Add PEP 561 compliant type stub files to enable static type checkers (Pylance, Pyright, mypy) to resolve dynamically created enums. The sc2.data module creates enums at runtime using enum.Enum() with protobuf descriptors, which are invisible to static analysis tools. --- pyproject.toml | 3 + sc2/data.pyi | 216 +++++++++++++++++++++++++++++++++++++++++++++++++ sc2/py.typed | 0 3 files changed, 219 insertions(+) create mode 100644 sc2/data.pyi create mode 100644 sc2/py.typed diff --git a/pyproject.toml b/pyproject.toml index 22aac81d..71b23f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ dev = [ license-files = [] package-dir = { sc2 = "sc2" } +[tool.setuptools.package-data] +sc2 = ["py.typed", "*.pyi"] + [build-system] # https://packaging.python.org/en/latest/tutorials/packaging-projects/#choosing-a-build-backend # https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#custom-discovery diff --git a/sc2/data.pyi b/sc2/data.pyi new file mode 100644 index 00000000..488ccc6b --- /dev/null +++ b/sc2/data.pyi @@ -0,0 +1,216 @@ +"""Type stubs for sc2.data module + +This stub provides static type information for dynamically generated enums. +The enums in sc2.data are created at runtime using enum.Enum() with protobuf +enum descriptors, which makes them invisible to static type checkers. + +This stub file (PEP 561 compliant) allows type checkers like Pylance, Pyright, +and mypy to understand the structure and members of these enums. +""" + +from enum import Enum +from typing import Dict, Set + +from sc2.ids.ability_id import AbilityId +from sc2.ids.unit_typeid import UnitTypeId + +# Enums created from sc2api_pb2 +class CreateGameError(Enum): + MissingMap: int + InvalidMapPath: int + InvalidMapData: int + InvalidMapName: int + InvalidMapHandle: int + MissingPlayerSetup: int + InvalidPlayerSetup: int + MultiplayerUnsupported: int + +class PlayerType(Enum): + Participant: int + Computer: int + Observer: int + +class Difficulty(Enum): + VeryEasy: int + Easy: int + Medium: int + MediumHard: int + Hard: int + Harder: int + VeryHard: int + CheatVision: int + CheatMoney: int + CheatInsane: int + +class AIBuild(Enum): + RandomBuild: int + Rush: int + Timing: int + Power: int + Macro: int + Air: int + +class Status(Enum): + launched: int + init_game: int + in_game: int + in_replay: int + ended: int + quit: int + unknown: int + +class Result(Enum): + Victory: int + Defeat: int + Tie: int + Undecided: int + +class Alert(Enum): + AlertError: int + AddOnComplete: int + BuildingComplete: int + BuildingUnderAttack: int + LarvaHatched: int + MergeComplete: int + MineralsExhausted: int + MorphComplete: int + MothershipComplete: int + MULEExpired: int + NuclearLaunchDetected: int + NukeComplete: int + NydusWormDetected: int + ResearchComplete: int + TrainError: int + TrainUnitComplete: int + TrainWorkerComplete: int + TransformationComplete: int + UnitUnderAttack: int + UpgradeComplete: int + VespeneExhausted: int + WarpInComplete: int + +class ChatChannel(Enum): + Broadcast: int + Team: int + +# Enums created from common_pb2 +class Race(Enum): + """StarCraft II race enum. + + Members: + NoRace: No race specified + Terran: Terran race + Zerg: Zerg race + Protoss: Protoss race + Random: Random race selection + """ + NoRace: int + Terran: int + Zerg: int + Protoss: int + Random: int + +# Enums created from raw_pb2 +class DisplayType(Enum): + Visible: int + Snapshot: int + Hidden: int + Placeholder: int + +class Alliance(Enum): + Self: int + Ally: int + Neutral: int + Enemy: int + +class CloakState(Enum): + CloakedUnknown: int + Cloaked: int + CloakedDetected: int + NotCloaked: int + CloakedAllied: int + +# Enums created from data_pb2 +class Attribute(Enum): + Light: int + Armored: int + Biological: int + Mechanical: int + Robotic: int + Psionic: int + Massive: int + Structure: int + Hover: int + Heroic: int + Summoned: int + +class TargetType(Enum): + Ground: int + Air: int + Any: int + +class Target(Enum): + # Note: The protobuf enum member 'None' is a Python keyword, + # so at runtime it may need special handling + Point: int + Unit: int + PointOrUnit: int + PointOrNone: int + +# Enums created from error_pb2 +class ActionResult(Enum): + """Action result codes from game engine. + + This enum contains a large number of members (~200+) representing + various action results and error conditions. Only the most commonly + used members are listed here. All members are available at runtime. + """ + Success: int + NotSupported: int + Error: int + CantQueueThatOrder: int + Retry: int + Cooldown: int + QueueIsFull: int + RallyQueueIsFull: int + NotEnoughMinerals: int + NotEnoughVespene: int + NotEnoughTerrazine: int + NotEnoughCustom: int + NotEnoughFood: int + FoodUsageImpossible: int + NotEnoughLife: int + NotEnoughShields: int + NotEnoughEnergy: int + LifeSuppressed: int + ShieldsSuppressed: int + EnergySuppressed: int + NotEnoughCharges: int + CantAddMoreCharges: int + TooMuchMinerals: int + TooMuchVespene: int + TooMuchTerrazine: int + TooMuchCustom: int + TooMuchFood: int + TooMuchLife: int + TooMuchShields: int + TooMuchEnergy: int + MustTargetUnitWithLife: int + MustTargetUnitWithShields: int + MustTargetUnitWithEnergy: int + CantTrade: int + CantSpend: int + CantTargetThatUnit: int + CouldntAllocateUnit: int + UnitCantMove: int + TransportIsHoldingPosition: int + BuildTechRequirementsNotMet: int + CantFindPlacementLocation: int + CantBuildOnThat: int + # ... approximately 150+ more members exist at runtime + +# Module-level dictionaries +race_worker: Dict[Race, UnitTypeId] +race_townhalls: Dict[Race, Set[UnitTypeId]] +warpgate_abilities: Dict[AbilityId, AbilityId] +race_gas: Dict[Race, UnitTypeId] diff --git a/sc2/py.typed b/sc2/py.typed new file mode 100644 index 00000000..e69de29b From 46a43830d77b615a798422f05d1454a4c466adfa Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Mon, 10 Nov 2025 20:39:46 +0100 Subject: [PATCH 02/34] Start adding stubs for s2clientprotocol --- pyproject.toml | 2 +- s2clientprotocol/__init__.pyi | 0 s2clientprotocol/common_pb2.pyi | 26 +++++ sc2/constants.py | 2 +- sc2/data.py | 5 +- sc2/data.pyi | 195 ++++++++++++++++++++++++++++++-- sc2/position.py | 4 +- 7 files changed, 216 insertions(+), 18 deletions(-) create mode 100644 s2clientprotocol/__init__.pyi create mode 100644 s2clientprotocol/common_pb2.pyi diff --git a/pyproject.toml b/pyproject.toml index 71b23f93..f634fb32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ dev = [ [tool.setuptools] license-files = [] -package-dir = { sc2 = "sc2" } +package-dir = { sc2 = "sc2", s2clientprotocol = "s2clientprotocol" } [tool.setuptools.package-data] sc2 = ["py.typed", "*.pyi"] diff --git a/s2clientprotocol/__init__.pyi b/s2clientprotocol/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/s2clientprotocol/common_pb2.pyi b/s2clientprotocol/common_pb2.pyi new file mode 100644 index 00000000..af539a99 --- /dev/null +++ b/s2clientprotocol/common_pb2.pyi @@ -0,0 +1,26 @@ +# https://github.com/Blizzard/s2client-proto/blob/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol/common.proto +from enum import Enum +from google.protobuf.message import Message + +class PointI(Message): + x: int + y: int + def __init__(self, x: int = ..., y: int = ...) -> None: ... + +class Point2D(Message): + x: float + y: float + def __init__(self, x: float = ..., y: float = ...) -> None: ... + +class Point(Message): + x: float + y: float + z: float + def __init__(self, x: float = ..., y: float = ..., z: float = ...) -> None: ... + +class Race(Enum): + NoRace: int + Terran: int + Zerg: int + Protoss: int + Random: int diff --git a/sc2/constants.py b/sc2/constants.py index e23d4c13..6478ff01 100644 --- a/sc2/constants.py +++ b/sc2/constants.py @@ -495,7 +495,7 @@ def return_NOTAUNIT() -> UnitTypeId: UnitTypeId.EXTRACTORRICH, } # pyre-ignore[11] -DAMAGE_BONUS_PER_UPGRADE: dict[UnitTypeId, dict[TargetType, Any]] = { +DAMAGE_BONUS_PER_UPGRADE: dict[UnitTypeId, dict[int, Any]] = { # # Protoss # diff --git a/sc2/data.py b/sc2/data.py index b0c9425f..a4377b8a 100644 --- a/sc2/data.py +++ b/sc2/data.py @@ -1,10 +1,7 @@ # pyre-ignore-all-errors[16, 19] """For the list of enums, see here -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_gametypes.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_action.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_unit.h -https://github.com/Blizzard/s2client-api/blob/d9ba0a33d6ce9d233c2a4ee988360c188fbe9dbf/include/sc2api/sc2_data.h +https://github.com/Blizzard/s2client-proto/tree/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol """ from __future__ import annotations diff --git a/sc2/data.pyi b/sc2/data.pyi index 488ccc6b..b625d769 100644 --- a/sc2/data.pyi +++ b/sc2/data.pyi @@ -8,8 +8,9 @@ This stub file (PEP 561 compliant) allows type checkers like Pylance, Pyright, and mypy to understand the structure and members of these enums. """ +from __future__ import annotations + from enum import Enum -from typing import Dict, Set from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId @@ -96,7 +97,7 @@ class ChatChannel(Enum): # Enums created from common_pb2 class Race(Enum): """StarCraft II race enum. - + Members: NoRace: No race specified Terran: Terran race @@ -104,6 +105,7 @@ class Race(Enum): Protoss: Protoss race Random: Random race selection """ + NoRace: int Terran: int Zerg: int @@ -143,11 +145,13 @@ class Attribute(Enum): Hover: int Heroic: int Summoned: int + Invalid: int class TargetType(Enum): Ground: int Air: int Any: int + Invalid: int class Target(Enum): # Note: The protobuf enum member 'None' is a Python keyword, @@ -160,11 +164,11 @@ class Target(Enum): # Enums created from error_pb2 class ActionResult(Enum): """Action result codes from game engine. - + This enum contains a large number of members (~200+) representing - various action results and error conditions. Only the most commonly - used members are listed here. All members are available at runtime. + various action results and error conditions. """ + Success: int NotSupported: int Error: int @@ -207,10 +211,181 @@ class ActionResult(Enum): BuildTechRequirementsNotMet: int CantFindPlacementLocation: int CantBuildOnThat: int - # ... approximately 150+ more members exist at runtime + CantBuildTooCloseToDropOff: int + CantBuildLocationInvalid: int + CantSeeBuildLocation: int + CantBuildTooCloseToCreepSource: int + CantBuildTooCloseToResources: int + CantBuildTooFarFromWater: int + CantBuildTooFarFromCreepSource: int + CantBuildTooFarFromBuildPowerSource: int + CantBuildOnDenseTerrain: int + CantTrainTooFarFromTrainPowerSource: int + CantLandLocationInvalid: int + CantSeeLandLocation: int + CantLandTooCloseToCreepSource: int + CantLandTooCloseToResources: int + CantLandTooFarFromWater: int + CantLandTooFarFromCreepSource: int + CantLandTooFarFromBuildPowerSource: int + CantLandTooFarFromTrainPowerSource: int + CantLandOnDenseTerrain: int + AddOnTooFarFromBuilding: int + MustBuildRefineryFirst: int + BuildingIsUnderConstruction: int + CantFindDropOff: int + CantLoadOtherPlayersUnits: int + NotEnoughRoomToLoadUnit: int + CantUnloadUnitsThere: int + CantWarpInUnitsThere: int + CantLoadImmobileUnits: int + CantRechargeImmobileUnits: int + CantRechargeUnderConstructionUnits: int + CantLoadThatUnit: int + NoCargoToUnload: int + LoadAllNoTargetsFound: int + NotWhileOccupied: int + CantAttackWithoutAmmo: int + CantHoldAnyMoreAmmo: int + TechRequirementsNotMet: int + MustLockdownUnitFirst: int + MustTargetUnit: int + MustTargetInventory: int + MustTargetVisibleUnit: int + MustTargetVisibleLocation: int + MustTargetWalkableLocation: int + MustTargetPawnableUnit: int + YouCantControlThatUnit: int + YouCantIssueCommandsToThatUnit: int + MustTargetResources: int + RequiresHealTarget: int + RequiresRepairTarget: int + NoItemsToDrop: int + CantHoldAnyMoreItems: int + CantHoldThat: int + TargetHasNoInventory: int + CantDropThisItem: int + CantMoveThisItem: int + CantPawnThisUnit: int + MustTargetCaster: int + CantTargetCaster: int + MustTargetOuter: int + CantTargetOuter: int + MustTargetYourOwnUnits: int + CantTargetYourOwnUnits: int + MustTargetFriendlyUnits: int + CantTargetFriendlyUnits: int + MustTargetNeutralUnits: int + CantTargetNeutralUnits: int + MustTargetEnemyUnits: int + CantTargetEnemyUnits: int + MustTargetAirUnits: int + CantTargetAirUnits: int + MustTargetGroundUnits: int + CantTargetGroundUnits: int + MustTargetStructures: int + CantTargetStructures: int + MustTargetLightUnits: int + CantTargetLightUnits: int + MustTargetArmoredUnits: int + CantTargetArmoredUnits: int + MustTargetBiologicalUnits: int + CantTargetBiologicalUnits: int + MustTargetHeroicUnits: int + CantTargetHeroicUnits: int + MustTargetRoboticUnits: int + CantTargetRoboticUnits: int + MustTargetMechanicalUnits: int + CantTargetMechanicalUnits: int + MustTargetPsionicUnits: int + CantTargetPsionicUnits: int + MustTargetMassiveUnits: int + CantTargetMassiveUnits: int + MustTargetMissile: int + CantTargetMissile: int + MustTargetWorkerUnits: int + CantTargetWorkerUnits: int + MustTargetEnergyCapableUnits: int + CantTargetEnergyCapableUnits: int + MustTargetShieldCapableUnits: int + CantTargetShieldCapableUnits: int + MustTargetFlyers: int + CantTargetFlyers: int + MustTargetBuriedUnits: int + CantTargetBuriedUnits: int + MustTargetCloakedUnits: int + CantTargetCloakedUnits: int + MustTargetUnitsInAStasisField: int + CantTargetUnitsInAStasisField: int + MustTargetUnderConstructionUnits: int + CantTargetUnderConstructionUnits: int + MustTargetDeadUnits: int + CantTargetDeadUnits: int + MustTargetRevivableUnits: int + CantTargetRevivableUnits: int + MustTargetHiddenUnits: int + CantTargetHiddenUnits: int + CantRechargeOtherPlayersUnits: int + MustTargetHallucinations: int + CantTargetHallucinations: int + MustTargetInvulnerableUnits: int + CantTargetInvulnerableUnits: int + MustTargetDetectedUnits: int + CantTargetDetectedUnits: int + CantTargetUnitWithEnergy: int + CantTargetUnitWithShields: int + MustTargetUncommandableUnits: int + CantTargetUncommandableUnits: int + MustTargetPreventDefeatUnits: int + CantTargetPreventDefeatUnits: int + MustTargetPreventRevealUnits: int + CantTargetPreventRevealUnits: int + MustTargetPassiveUnits: int + CantTargetPassiveUnits: int + MustTargetStunnedUnits: int + CantTargetStunnedUnits: int + MustTargetSummonedUnits: int + CantTargetSummonedUnits: int + MustTargetUser1: int + CantTargetUser1: int + MustTargetUnstoppableUnits: int + CantTargetUnstoppableUnits: int + MustTargetResistantUnits: int + CantTargetResistantUnits: int + MustTargetDazedUnits: int + CantTargetDazedUnits: int + CantLockdown: int + CantMindControl: int + MustTargetDestructibles: int + CantTargetDestructibles: int + MustTargetItems: int + CantTargetItems: int + NoCalldownAvailable: int + WaypointListFull: int + MustTargetRace: int + CantTargetRace: int + MustTargetSimilarUnits: int + CantTargetSimilarUnits: int + CantFindEnoughTargets: int + AlreadySpawningLarva: int + CantTargetExhaustedResources: int + CantUseMinimap: int + CantUseInfoPanel: int + OrderQueueIsFull: int + CantHarvestThatResource: int + HarvestersNotRequired: int + AlreadyTargeted: int + CantAttackWeaponsDisabled: int + CouldntReachTarget: int + TargetIsOutOfRange: int + TargetIsTooClose: int + TargetIsOutOfArc: int + CantFindTeleportLocation: int + InvalidItemClass: int + CantFindCancelOrder: int # Module-level dictionaries -race_worker: Dict[Race, UnitTypeId] -race_townhalls: Dict[Race, Set[UnitTypeId]] -warpgate_abilities: Dict[AbilityId, AbilityId] -race_gas: Dict[Race, UnitTypeId] +race_worker: dict[Race, UnitTypeId] +race_townhalls: dict[Race, set[UnitTypeId]] +warpgate_abilities: dict[AbilityId, AbilityId] +race_gas: dict[Race, UnitTypeId] diff --git a/sc2/position.py b/sc2/position.py index 36a0922f..3e37ad2f 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -142,7 +142,7 @@ def __hash__(self) -> int: class Point2(Pointlike): @classmethod - def from_proto(cls, data) -> Point2: + def from_proto(cls, data: common_pb.Point2D) -> Point2: """ :param data: """ @@ -324,7 +324,7 @@ def center(points: list[Point2]) -> Point2: class Point3(Point2): @classmethod - def from_proto(cls, data) -> Point3: + def from_proto(cls, data: common_pb.Point) -> Point3: """ :param data: """ From 331fa88c6b6a5102933d004f9089bc465871f5e9 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Thu, 13 Nov 2025 18:32:12 +0100 Subject: [PATCH 03/34] Finalize common_pb2 --- s2clientprotocol/common_pb2.pyi | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/s2clientprotocol/common_pb2.pyi b/s2clientprotocol/common_pb2.pyi index af539a99..2a4028c7 100644 --- a/s2clientprotocol/common_pb2.pyi +++ b/s2clientprotocol/common_pb2.pyi @@ -2,11 +2,27 @@ from enum import Enum from google.protobuf.message import Message +class AvailableAbility(Message): + ability_id: int + requires_point: bool + def __init__(self, ability_id: int = ..., requires_point: bool = ...) -> None: ... + +class ImageData(Message): + bits_per_pixel: int + size: Size2DI + data: bytes + def __init__(self, bits_per_pixel: int = ..., size: Size2DI = ..., data: bytes = ...) -> None: ... + class PointI(Message): x: int y: int def __init__(self, x: int = ..., y: int = ...) -> None: ... +class RectangleI(Message): + p0: PointI + p1: PointI + def __init__(self, p0: PointI = ..., p1: PointI = ...) -> None: ... + class Point2D(Message): x: float y: float @@ -18,6 +34,11 @@ class Point(Message): z: float def __init__(self, x: float = ..., y: float = ..., z: float = ...) -> None: ... +class Size2DI(Message): + x: int + y: int + def __init__(self, x: int = ..., y: int = ...) -> None: ... + class Race(Enum): NoRace: int Terran: int From bf3e6e7be47e35489560d5692fb9e830b92bcc54 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 02:08:33 +0100 Subject: [PATCH 04/34] Finalize data_pb2 --- s2clientprotocol/data_pb2.pyi | 169 ++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 s2clientprotocol/data_pb2.pyi diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi new file mode 100644 index 00000000..2ba8f190 --- /dev/null +++ b/s2clientprotocol/data_pb2.pyi @@ -0,0 +1,169 @@ +from enum import Enum +from google.protobuf.message import Message +from .common_pb2 import Race + +class Target(Enum): + # NONE: int + Point: int + Unit: int + PointOrUnit: int + PointOrNone: int + +class AbilityData(Message): + ability_id: int + link_name: str + link_index: int + button_name: str + friendly_name: str + hotkey: str + remaps_to_ability_id: int + available: bool + target: Target + allow_minimap: bool + allow_autocast: bool + is_building: bool + footprint_radius: float + is_instant_placement: bool + cast_range: float + def __init__( + self, + ability_id: int = ..., + link_name: str = ..., + link_index: int = ..., + button_name: str = ..., + friendly_name: str = ..., + hotkey: str = ..., + remaps_to_ability_id: int = ..., + available: bool = ..., + target: Target = ..., + allow_minimap: bool = ..., + allow_autocast: bool = ..., + is_building: bool = ..., + footprint_radius: float = ..., + is_instant_placement: bool = ..., + cast_range: float = ..., + ) -> None: ... + +class Attribute(Enum): + Light: int + Armored: int + Biological: int + Mechanical: int + Robotic: int + Psionic: int + Massive: int + Structure: int + Hover: int + Heroic: int + Summoned: int + +class DamageBonus(Message): + attribute: Attribute + bonus: float + def __init__(self, attribute: Attribute = ..., bonus: float = ...) -> None: ... + +class TargetType(Enum): + Ground: int + Air: int + Any: int + +class Weapon(Message): + type: TargetType + damage: float + damage_bonus: list[DamageBonus] + attacks: int + range: float + speed: float + def __init__( + self, + type: TargetType = ..., + damage: float = ..., + damage_bonus: list[DamageBonus] = ..., + attacks: int = ..., + range: float = ..., + speed: float = ..., + ) -> None: ... + +class UnitTypeData(Message): + unit_id: int + name: str + available: bool + cargo_size: int + mineral_cost: int + vespene_cost: int + food_required: float + food_provided: float + ability_id: int + race: Race + build_time: float + has_vespene: bool + has_minerals: bool + sight_range: float + tech_alias: list[int] + unit_alias: int + tech_requirement: int + require_attached: bool + attributes: list[Attribute] + movement_speed: float + armor: float + weapons: list[Weapon] + def __init__( + self, + unit_id: int = ..., + name: str = ..., + available: bool = ..., + cargo_size: int = ..., + mineral_cost: int = ..., + vespene_cost: int = ..., + food_required: float = ..., + food_provided: float = ..., + ability_id: int = ..., + race: Race = ..., + build_time: float = ..., + has_vespene: bool = ..., + has_minerals: bool = ..., + sight_range: float = ..., + tech_alias: list[int] = ..., + unit_alias: int = ..., + tech_requirement: int = ..., + require_attached: bool = ..., + attributes: list[Attribute] = ..., + movement_speed: float = ..., + armor: float = ..., + weapons: list[Weapon] = ..., + ) -> None: ... + +class UpgradeData(Message): + upgrade_id: int + name: str + mineral_cost: int + vespene_cost: int + research_time: float + ability_id: int + def __init__( + self, + upgrade_id: int = ..., + name: str = ..., + mineral_cost: int = ..., + vespene_cost: int = ..., + research_time: float = ..., + ability_id: int = ..., + ) -> None: ... + +class BuffData(Message): + buff_id: int + name: str + def __init__(self, buff_id: int = ..., name: str = ...) -> None: ... + +class EffectData(Message): + effect_id: int + name: str + friendly_name: str + radius: float + def __init__( + self, + effect_id: int = ..., + name: str = ..., + friendly_name: str = ..., + radius: float = ..., + ) -> None: ... From 62b8f39c605c8db6ad9d192628f65c6c7aab61e7 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 02:14:46 +0100 Subject: [PATCH 05/34] Finalize debug_pb2 --- s2clientprotocol/debug_pb2.pyi | 149 +++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 s2clientprotocol/debug_pb2.pyi diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi new file mode 100644 index 00000000..ea6617fe --- /dev/null +++ b/s2clientprotocol/debug_pb2.pyi @@ -0,0 +1,149 @@ +from enum import Enum +from google.protobuf.message import Message +from .common_pb2 import Point, Point2D + +class DebugCommand(Message): + draw: DebugDraw + game_state: DebugGameState + create_unit: DebugCreateUnit + kill_unit: DebugKillUnit + test_process: DebugTestProcess + score: DebugSetScore + end_game: DebugEndGame + unit_value: DebugSetUnitValue + def __init__( + self, + draw: DebugDraw = ..., + game_state: DebugGameState = ..., + create_unit: DebugCreateUnit = ..., + kill_unit: DebugKillUnit = ..., + test_process: DebugTestProcess = ..., + score: DebugSetScore = ..., + end_game: DebugEndGame = ..., + unit_value: DebugSetUnitValue = ..., + ) -> None: ... + +class DebugDraw(Message): + text: list[DebugText] + lines: list[DebugLine] + boxes: list[DebugBox] + spheres: list[DebugSphere] + def __init__( + self, + text: list[DebugText] = ..., + lines: list[DebugLine] = ..., + boxes: list[DebugBox] = ..., + spheres: list[DebugSphere] = ..., + ) -> None: ... + +class Line(Message): + p0: Point + p1: Point + def __init__(self, p0: Point = ..., p1: Point = ...) -> None: ... + +class Color(Message): + r: int + g: int + b: int + def __init__(self, r: int = ..., g: int = ..., b: int = ...) -> None: ... + +class DebugText(Message): + color: Color + text: str + virtual_pos: Point + world_pos: Point + size: int + def __init__( + self, + color: Color = ..., + text: str = ..., + virtual_pos: Point = ..., + world_pos: Point = ..., + size: int = ..., + ) -> None: ... + +class DebugLine(Message): + color: Color + line: Line + def __init__(self, color: Color = ..., line: Line = ...) -> None: ... + +class DebugBox(Message): + color: Color + min: Point + max: Point + def __init__(self, color: Color = ..., min: Point = ..., max: Point = ...) -> None: ... + +class DebugSphere(Message): + color: Color + p: Point + r: float + def __init__(self, color: Color = ..., p: Point = ..., r: float = ...) -> None: ... + +class DebugGameState(Enum): + show_map: int + control_enemy: int + food: int + free: int + all_resources: int + god: int + minerals: int + gas: int + cooldown: int + tech_tree: int + upgrade: int + fast_build: int + +class DebugCreateUnit(Message): + unit_type: int + owner: int + pos: Point2D + quantity: int + def __init__( + self, + unit_type: int = ..., + owner: int = ..., + pos: Point2D = ..., + quantity: int = ..., + ) -> None: ... + +class DebugKillUnit(Message): + tag: list[int] + def __init__(self, tag: list[int] = ...) -> None: ... + +class Test(Enum): + hang: int + crash: int + exit: int + +class DebugTestProcess(Message): + test: Test + delay_ms: int + def __init__(self, test: Test = ..., delay_ms: int = ...) -> None: ... + +class DebugSetScore(Message): + score: float + def __init__(self, score: float = ...) -> None: ... + +class EndResult(Enum): + Surrender: int + DeclareVictory: int + +class DebugEndGame(Message): + end_result: EndResult + def __init__(self, end_result: EndResult = ...) -> None: ... + +class UnitValue(Enum): + Energy: int + Life: int + Shields: int + +class DebugSetUnitValue(Message): + unit_value: UnitValue + value: float + unit_tag: int + def __init__( + self, + unit_value: UnitValue = ..., + value: float = ..., + unit_tag: int = ..., + ) -> None: ... From 4d45c45ad64b50878f6f34abbf5d6b9b10919e11 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 02:15:26 +0100 Subject: [PATCH 06/34] Finalize error_pb2 --- s2clientprotocol/error_pb2.pyi | 217 +++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 s2clientprotocol/error_pb2.pyi diff --git a/s2clientprotocol/error_pb2.pyi b/s2clientprotocol/error_pb2.pyi new file mode 100644 index 00000000..017262d9 --- /dev/null +++ b/s2clientprotocol/error_pb2.pyi @@ -0,0 +1,217 @@ +from enum import Enum + +class ActionResult(Enum): + Success: int + NotSupported: int + Error: int + CantQueueThatOrder: int + Retry: int + Cooldown: int + QueueIsFull: int + RallyQueueIsFull: int + NotEnoughMinerals: int + NotEnoughVespene: int + NotEnoughTerrazine: int + NotEnoughCustom: int + NotEnoughFood: int + FoodUsageImpossible: int + NotEnoughLife: int + NotEnoughShields: int + NotEnoughEnergy: int + LifeSuppressed: int + ShieldsSuppressed: int + EnergySuppressed: int + NotEnoughCharges: int + CantAddMoreCharges: int + TooMuchMinerals: int + TooMuchVespene: int + TooMuchTerrazine: int + TooMuchCustom: int + TooMuchFood: int + TooMuchLife: int + TooMuchShields: int + TooMuchEnergy: int + MustTargetUnitWithLife: int + MustTargetUnitWithShields: int + MustTargetUnitWithEnergy: int + CantTrade: int + CantSpend: int + CantTargetThatUnit: int + CouldntAllocateUnit: int + UnitCantMove: int + TransportIsHoldingPosition: int + BuildTechRequirementsNotMet: int + CantFindPlacementLocation: int + CantBuildOnThat: int + CantBuildTooCloseToDropOff: int + CantBuildLocationInvalid: int + CantSeeBuildLocation: int + CantBuildTooCloseToCreepSource: int + CantBuildTooCloseToResources: int + CantBuildTooFarFromWater: int + CantBuildTooFarFromCreepSource: int + CantBuildTooFarFromBuildPowerSource: int + CantBuildOnDenseTerrain: int + CantTrainTooFarFromTrainPowerSource: int + CantLandLocationInvalid: int + CantSeeLandLocation: int + CantLandTooCloseToCreepSource: int + CantLandTooCloseToResources: int + CantLandTooFarFromWater: int + CantLandTooFarFromCreepSource: int + CantLandTooFarFromBuildPowerSource: int + CantLandTooFarFromTrainPowerSource: int + CantLandOnDenseTerrain: int + AddOnTooFarFromBuilding: int + MustBuildRefineryFirst: int + BuildingIsUnderConstruction: int + CantFindDropOff: int + CantLoadOtherPlayersUnits: int + NotEnoughRoomToLoadUnit: int + CantUnloadUnitsThere: int + CantWarpInUnitsThere: int + CantLoadImmobileUnits: int + CantRechargeImmobileUnits: int + CantRechargeUnderConstructionUnits: int + CantLoadThatUnit: int + NoCargoToUnload: int + LoadAllNoTargetsFound: int + NotWhileOccupied: int + CantAttackWithoutAmmo: int + CantHoldAnyMoreAmmo: int + TechRequirementsNotMet: int + MustLockdownUnitFirst: int + MustTargetUnit: int + MustTargetInventory: int + MustTargetVisibleUnit: int + MustTargetVisibleLocation: int + MustTargetWalkableLocation: int + MustTargetPawnableUnit: int + YouCantControlThatUnit: int + YouCantIssueCommandsToThatUnit: int + MustTargetResources: int + RequiresHealTarget: int + RequiresRepairTarget: int + NoItemsToDrop: int + CantHoldAnyMoreItems: int + CantHoldThat: int + TargetHasNoInventory: int + CantDropThisItem: int + CantMoveThisItem: int + CantPawnThisUnit: int + MustTargetCaster: int + CantTargetCaster: int + MustTargetOuter: int + CantTargetOuter: int + MustTargetYourOwnUnits: int + CantTargetYourOwnUnits: int + MustTargetFriendlyUnits: int + CantTargetFriendlyUnits: int + MustTargetNeutralUnits: int + CantTargetNeutralUnits: int + MustTargetEnemyUnits: int + CantTargetEnemyUnits: int + MustTargetAirUnits: int + CantTargetAirUnits: int + MustTargetGroundUnits: int + CantTargetGroundUnits: int + MustTargetStructures: int + CantTargetStructures: int + MustTargetLightUnits: int + CantTargetLightUnits: int + MustTargetArmoredUnits: int + CantTargetArmoredUnits: int + MustTargetBiologicalUnits: int + CantTargetBiologicalUnits: int + MustTargetHeroicUnits: int + CantTargetHeroicUnits: int + MustTargetRoboticUnits: int + CantTargetRoboticUnits: int + MustTargetMechanicalUnits: int + CantTargetMechanicalUnits: int + MustTargetPsionicUnits: int + CantTargetPsionicUnits: int + MustTargetMassiveUnits: int + CantTargetMassiveUnits: int + MustTargetMissile: int + CantTargetMissile: int + MustTargetWorkerUnits: int + CantTargetWorkerUnits: int + MustTargetEnergyCapableUnits: int + CantTargetEnergyCapableUnits: int + MustTargetShieldCapableUnits: int + CantTargetShieldCapableUnits: int + MustTargetFlyers: int + CantTargetFlyers: int + MustTargetBuriedUnits: int + CantTargetBuriedUnits: int + MustTargetCloakedUnits: int + CantTargetCloakedUnits: int + MustTargetUnitsInAStasisField: int + CantTargetUnitsInAStasisField: int + MustTargetUnderConstructionUnits: int + CantTargetUnderConstructionUnits: int + MustTargetDeadUnits: int + CantTargetDeadUnits: int + MustTargetRevivableUnits: int + CantTargetRevivableUnits: int + MustTargetHiddenUnits: int + CantTargetHiddenUnits: int + CantRechargeOtherPlayersUnits: int + MustTargetHallucinations: int + CantTargetHallucinations: int + MustTargetInvulnerableUnits: int + CantTargetInvulnerableUnits: int + MustTargetDetectedUnits: int + CantTargetDetectedUnits: int + CantTargetUnitWithEnergy: int + CantTargetUnitWithShields: int + MustTargetUncommandableUnits: int + CantTargetUncommandableUnits: int + MustTargetPreventDefeatUnits: int + CantTargetPreventDefeatUnits: int + MustTargetPreventRevealUnits: int + CantTargetPreventRevealUnits: int + MustTargetPassiveUnits: int + CantTargetPassiveUnits: int + MustTargetStunnedUnits: int + CantTargetStunnedUnits: int + MustTargetSummonedUnits: int + CantTargetSummonedUnits: int + MustTargetUser1: int + CantTargetUser1: int + MustTargetUnstoppableUnits: int + CantTargetUnstoppableUnits: int + MustTargetResistantUnits: int + CantTargetResistantUnits: int + MustTargetDazedUnits: int + CantTargetDazedUnits: int + CantLockdown: int + CantMindControl: int + MustTargetDestructibles: int + CantTargetDestructibles: int + MustTargetItems: int + CantTargetItems: int + NoCalldownAvailable: int + WaypointListFull: int + MustTargetRace: int + CantTargetRace: int + MustTargetSimilarUnits: int + CantTargetSimilarUnits: int + CantFindEnoughTargets: int + AlreadySpawningLarva: int + CantTargetExhaustedResources: int + CantUseMinimap: int + CantUseInfoPanel: int + OrderQueueIsFull: int + CantHarvestThatResource: int + HarvestersNotRequired: int + AlreadyTargeted: int + CantAttackWeaponsDisabled: int + CouldntReachTarget: int + TargetIsOutOfRange: int + TargetIsTooClose: int + TargetIsOutOfArc: int + CantFindTeleportLocation: int + InvalidItemClass: int + CantFindCancelOrder: int From 68e49e55a0ee9887aa8e131d278c592f47dfdddb Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 02:24:21 +0100 Subject: [PATCH 07/34] Finalize query_pb2 --- s2clientprotocol/query_pb2.pyi | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 s2clientprotocol/query_pb2.pyi diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi new file mode 100644 index 00000000..dfd7823b --- /dev/null +++ b/s2clientprotocol/query_pb2.pyi @@ -0,0 +1,72 @@ +from google.protobuf.message import Message +from .common_pb2 import Point2D, AvailableAbility +from .error_pb2 import ActionResult + +class RequestQuery(Message): + pathing: list[RequestQueryPathing] + abilities: list[RequestQueryAvailableAbilities] + placements: list[RequestQueryBuildingPlacement] + ignore_resource_requirements: bool + def __init__( + self, + pathing: list[RequestQueryPathing] = ..., + abilities: list[RequestQueryAvailableAbilities] = ..., + placements: list[RequestQueryBuildingPlacement] = ..., + ignore_resource_requirements: bool = ..., + ) -> None: ... + +class ResponseQuery(Message): + pathing: list[ResponseQueryPathing] + abilities: list[ResponseQueryAvailableAbilities] + placements: list[ResponseQueryBuildingPlacement] + def __init__( + self, + pathing: list[ResponseQueryPathing] = ..., + abilities: list[ResponseQueryAvailableAbilities] = ..., + placements: list[ResponseQueryBuildingPlacement] = ..., + ) -> None: ... + +class RequestQueryPathing(Message): + start_pos: Point2D + unit_tag: int + end_pos: Point2D + def __init__( + self, + start_pos: Point2D = ..., + unit_tag: int = ..., + end_pos: Point2D = ..., + ) -> None: ... + +class ResponseQueryPathing(Message): + distance: float + def __init__(self, distance: float = ...) -> None: ... + +class RequestQueryAvailableAbilities(Message): + unit_tag: int + def __init__(self, unit_tag: int = ...) -> None: ... + +class ResponseQueryAvailableAbilities(Message): + abilities: list[AvailableAbility] + unit_tag: int + unit_type_id: int + def __init__( + self, + abilities: list[AvailableAbility] = ..., + unit_tag: int = ..., + unit_type_id: int = ..., + ) -> None: ... + +class RequestQueryBuildingPlacement(Message): + ability_id: int + target_pos: Point2D + placing_unit_tag: int + def __init__( + self, + ability_id: int = ..., + target_pos: Point2D = ..., + placing_unit_tag: int = ..., + ) -> None: ... + +class ResponseQueryBuildingPlacement(Message): + result: ActionResult + def __init__(self, result: ActionResult = ...) -> None: ... From eb96ff8be5005d327ee86b9b64fbce43f96a872b Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 02:53:04 +0100 Subject: [PATCH 08/34] Finalize raw_pb2 --- s2clientprotocol/raw_pb2.pyi | 269 +++++++++++++++++++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 s2clientprotocol/raw_pb2.pyi diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi new file mode 100644 index 00000000..ac4a7090 --- /dev/null +++ b/s2clientprotocol/raw_pb2.pyi @@ -0,0 +1,269 @@ +from enum import Enum +from google.protobuf.message import Message +from .common_pb2 import Point2D, Point, Size2DI, ImageData, RectangleI + +class StartRaw(Message): + map_size: Size2DI + pathing_grid: ImageData + terrain_height: ImageData + placement_grid: ImageData + playable_area: RectangleI + start_locations: list[Point2D] + def __init__( + self, + map_size: Size2DI = ..., + pathing_grid: ImageData = ..., + terrain_height: ImageData = ..., + placement_grid: ImageData = ..., + playable_area: RectangleI = ..., + start_locations: list[Point2D] = ..., + ) -> None: ... + +class ObservationRaw(Message): + player: PlayerRaw + units: list[Unit] + map_state: MapState + event: Event + effects: list[Effect] + radar: list[RadarRing] + def __init__( + self, + player: PlayerRaw = ..., + units: list[Unit] = ..., + map_state: MapState = ..., + event: Event = ..., + effects: list[Effect] = ..., + radar: list[RadarRing] = ..., + ) -> None: ... + +class RadarRing(Message): + pos: Point + radius: float + def __init__(self, pos: Point = ..., radius: float = ...) -> None: ... + +class PowerSource(Message): + pos: Point + radius: float + tag: int + def __init__(self, pos: Point = ..., radius: float = ..., tag: int = ...) -> None: ... + +class PlayerRaw(Message): + power_sources: list[PowerSource] + camera: Point + upgrade_ids: list[int] + def __init__( + self, + power_sources: list[PowerSource] = ..., + camera: Point = ..., + upgrade_ids: list[int] = ..., + ) -> None: ... + +class UnitOrder(Message): + ability_id: int + target_world_space_pos: Point + target_unit_tag: int + progress: float + def __init__( + self, + ability_id: int = ..., + target_world_space_pos: Point = ..., + target_unit_tag: int = ..., + progress: float = ..., + ) -> None: ... + +class DisplayType(Enum): + Visible: int + Snapshot: int + Hidden: int + Placeholder: int + +class Alliance(Enum): + Self: int + Ally: int + Neutral: int + Enemy: int + +class CloakState(Enum): + CloakedUnknown: int + Cloaked: int + CloakedDetected: int + NotCloaked: int + CloakedAllied: int + +class PassengerUnit(Message): + tag: int + health: float + health_max: float + shield: float + shield_max: float + energy: float + energy_max: float + unit_type: int + def __init__( + self, + tag: int = ..., + health: float = ..., + health_max: float = ..., + shield: float = ..., + shield_max: float = ..., + energy: float = ..., + energy_max: float = ..., + unit_type: int = ..., + ) -> None: ... + +class RallyTarget(Message): + point: Point + tag: int + def __init__(self, point: Point = ..., tag: int = ...) -> None: ... + +class Unit(Message): + display_type: DisplayType + alliance: Alliance + tag: int + unit_type: int + owner: int + pos: Point + facing: float + radius: float + build_progress: float + cloak: CloakState + buff_ids: list[int] + detect_range: float + radar_range: float + is_selected: bool + is_on_screen: bool + is_blip: bool + is_powered: bool + is_active: bool + attack_upgrade_level: int + armor_upgrade_level: int + shield_upgrade_level: int + health: float + health_max: float + shield: float + shield_max: float + energy: float + energy_max: float + mineral_contents: int + vespene_contents: int + is_flying: bool + is_burrowed: bool + is_hallucination: bool + orders: list[UnitOrder] + add_on_tag: int + passengers: list[PassengerUnit] + cargo_space_taken: int + cargo_space_max: int + assigned_harvesters: int + ideal_harvesters: int + weapon_cooldown: float + engaged_target_tag: int + buff_duration_remain: int + buff_duration_max: int + rally_targets: list[RallyTarget] + def __init__( + self, + display_type: DisplayType = ..., + alliance: Alliance = ..., + tag: int = ..., + unit_type: int = ..., + owner: int = ..., + pos: Point = ..., + facing: float = ..., + radius: float = ..., + build_progress: float = ..., + cloak: CloakState = ..., + buff_ids: list[int] = ..., + detect_range: float = ..., + radar_range: float = ..., + is_selected: bool = ..., + is_on_screen: bool = ..., + is_blip: bool = ..., + is_powered: bool = ..., + is_active: bool = ..., + attack_upgrade_level: int = ..., + armor_upgrade_level: int = ..., + shield_upgrade_level: int = ..., + health: float = ..., + health_max: float = ..., + shield: float = ..., + shield_max: float = ..., + energy: float = ..., + energy_max: float = ..., + mineral_contents: int = ..., + vespene_contents: int = ..., + is_flying: bool = ..., + is_burrowed: bool = ..., + is_hallucination: bool = ..., + orders: list[UnitOrder] = ..., + add_on_tag: int = ..., + passengers: list[PassengerUnit] = ..., + cargo_space_taken: int = ..., + cargo_space_max: int = ..., + assigned_harvesters: int = ..., + ideal_harvesters: int = ..., + weapon_cooldown: float = ..., + engaged_target_tag: int = ..., + buff_duration_remain: int = ..., + buff_duration_max: int = ..., + rally_targets: list[RallyTarget] = ..., + ) -> None: ... + +class MapState(Message): + visibility: ImageData + creep: ImageData + def __init__(self, visibility: ImageData = ..., creep: ImageData = ...) -> None: ... + +class Event(Message): + dead_units: list[int] + def __init__(self, dead_units: list[int] = ...) -> None: ... + +class Effect(Message): + effect_id: int + pos: list[Point2D] + alliance: Alliance + owner: int + radius: float + def __init__( + self, + effect_id: int = ..., + pos: list[Point2D] = ..., + alliance: Alliance = ..., + owner: int = ..., + radius: float = ..., + ) -> None: ... + +class ActionRaw(Message): + unit_command: ActionRawUnitCommand + camera_move: ActionRawCameraMove + toggle_autocast: ActionRawToggleAutocast + def __init__( + self, + unit_command: ActionRawUnitCommand = ..., + camera_move: ActionRawCameraMove = ..., + toggle_autocast: ActionRawToggleAutocast = ..., + ) -> None: ... + +class ActionRawUnitCommand(Message): + ability_id: int + target_world_space_pos: Point2D + target_unit_tag: int + unit_tags: list[int] + queue_command: bool + def __init__( + self, + ability_id: int = ..., + target_world_space_pos: Point2D = ..., + target_unit_tag: int = ..., + unit_tags: list[int] = ..., + queue_command: bool = ..., + ) -> None: ... + +class ActionRawCameraMove(Message): + center_world_space: Point + def __init__(self, center_world_space: Point = ...) -> None: ... + +class ActionRawToggleAutocast(Message): + ability_id: int + unit_tags: list[int] + def __init__(self, ability_id: int = ..., unit_tags: list[int] = ...) -> None: ... From 888a3b2a08ae374b7044ff89cda444b718a40931 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 03:10:15 +0100 Subject: [PATCH 09/34] Finalize score_pb2 --- s2clientprotocol/score_pb2.pyi | 105 +++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 s2clientprotocol/score_pb2.pyi diff --git a/s2clientprotocol/score_pb2.pyi b/s2clientprotocol/score_pb2.pyi new file mode 100644 index 00000000..61d54c01 --- /dev/null +++ b/s2clientprotocol/score_pb2.pyi @@ -0,0 +1,105 @@ +from __future__ import annotations +from enum import Enum +from google.protobuf.message import Message + +class ScoreType(Enum): + Curriculum: int + Melee: int + +class Score(Message): + score_type: ScoreType + score: int + score_details: ScoreDetails + def __init__( + self, + score_type: ScoreType = ..., + score: int = ..., + score_details: ScoreDetails = ..., + ) -> None: ... + +class CategoryScoreDetails(Message): + none: float + army: float + economy: float + technology: float + upgrade: float + def __init__( + self, + none: float = ..., + army: float = ..., + economy: float = ..., + technology: float = ..., + upgrade: float = ..., + ) -> None: ... + +class VitalScoreDetails(Message): + life: float + shields: float + energy: float + def __init__( + self, + life: float = ..., + shields: float = ..., + energy: float = ..., + ) -> None: ... + +class ScoreDetails(Message): + idle_production_time: float + idle_worker_time: float + total_value_units: float + total_value_structures: float + killed_value_units: float + killed_value_structures: float + collected_minerals: float + collected_vespene: float + collection_rate_minerals: float + collection_rate_vespene: float + spent_minerals: float + spent_vespene: float + food_used: CategoryScoreDetails + killed_minerals: CategoryScoreDetails + killed_vespene: CategoryScoreDetails + lost_minerals: CategoryScoreDetails + lost_vespene: CategoryScoreDetails + friendly_fire_minerals: CategoryScoreDetails + friendly_fire_vespene: CategoryScoreDetails + used_minerals: CategoryScoreDetails + used_vespene: CategoryScoreDetails + total_used_minerals: CategoryScoreDetails + total_used_vespene: CategoryScoreDetails + total_damage_dealt: VitalScoreDetails + total_damage_taken: VitalScoreDetails + total_healed: VitalScoreDetails + current_apm: float + current_effective_apm: float + def __init__( + self, + idle_production_time: float = ..., + idle_worker_time: float = ..., + total_value_units: float = ..., + total_value_structures: float = ..., + killed_value_units: float = ..., + killed_value_structures: float = ..., + collected_minerals: float = ..., + collected_vespene: float = ..., + collection_rate_minerals: float = ..., + collection_rate_vespene: float = ..., + spent_minerals: float = ..., + spent_vespene: float = ..., + food_used: CategoryScoreDetails = ..., + killed_minerals: CategoryScoreDetails = ..., + killed_vespene: CategoryScoreDetails = ..., + lost_minerals: CategoryScoreDetails = ..., + lost_vespene: CategoryScoreDetails = ..., + friendly_fire_minerals: CategoryScoreDetails = ..., + friendly_fire_vespene: CategoryScoreDetails = ..., + used_minerals: CategoryScoreDetails = ..., + used_vespene: CategoryScoreDetails = ..., + total_used_minerals: CategoryScoreDetails = ..., + total_used_vespene: CategoryScoreDetails = ..., + total_damage_dealt: VitalScoreDetails = ..., + total_damage_taken: VitalScoreDetails = ..., + total_healed: VitalScoreDetails = ..., + current_apm: float = ..., + current_effective_apm: float = ..., + ) -> None: ... From 4cf1920996f2a64f3c568c53cb181319befaba90 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 03:16:29 +0100 Subject: [PATCH 10/34] Finalize ui_pb2 --- s2clientprotocol/ui_pb2.pyi | 181 ++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 s2clientprotocol/ui_pb2.pyi diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi new file mode 100644 index 00000000..4fb1507d --- /dev/null +++ b/s2clientprotocol/ui_pb2.pyi @@ -0,0 +1,181 @@ +from __future__ import annotations +from enum import Enum +from google.protobuf.message import Message + +class ObservationUI(Message): + groups: list[ControlGroup] + single: SinglePanel + multi: MultiPanel + cargo: CargoPanel + production: ProductionPanel + def __init__( + self, + groups: list[ControlGroup] = ..., + single: SinglePanel = ..., + multi: MultiPanel = ..., + cargo: CargoPanel = ..., + production: ProductionPanel = ..., + ) -> None: ... + +class ControlGroup(Message): + control_group_index: int + leader_unit_type: int + count: int + def __init__( + self, + control_group_index: int = ..., + leader_unit_type: int = ..., + count: int = ..., + ) -> None: ... + +class UnitInfo(Message): + unit_type: int + player_relative: int + health: int + shields: int + energy: int + transport_slots_taken: int + build_progress: float + add_on: UnitInfo + max_health: int + max_shields: int + max_energy: int + def __init__( + self, + unit_type: int = ..., + player_relative: int = ..., + health: int = ..., + shields: int = ..., + energy: int = ..., + transport_slots_taken: int = ..., + build_progress: float = ..., + add_on: UnitInfo = ..., + max_health: int = ..., + max_shields: int = ..., + max_energy: int = ..., + ) -> None: ... + +class SinglePanel(Message): + unit: UnitInfo + attack_upgrade_level: int + armor_upgrade_level: int + shield_upgrade_level: int + buffs: list[int] + def __init__( + self, + unit: UnitInfo = ..., + attack_upgrade_level: int = ..., + armor_upgrade_level: int = ..., + shield_upgrade_level: int = ..., + buffs: list[int] = ..., + ) -> None: ... + +class MultiPanel(Message): + units: list[UnitInfo] + def __init__(self, units: list[UnitInfo] = ...) -> None: ... + +class CargoPanel(Message): + unit: UnitInfo + passengers: list[UnitInfo] + slots_available: int + def __init__( + self, + unit: UnitInfo = ..., + passengers: list[UnitInfo] = ..., + slots_available: int = ..., + ) -> None: ... + +class BuildItem(Message): + ability_id: int + build_progress: float + def __init__(self, ability_id: int = ..., build_progress: float = ...) -> None: ... + +class ProductionPanel(Message): + unit: UnitInfo + build_queue: list[UnitInfo] + production_queue: list[BuildItem] + def __init__( + self, + unit: UnitInfo = ..., + build_queue: list[UnitInfo] = ..., + production_queue: list[BuildItem] = ..., + ) -> None: ... + +class ActionUI(Message): + control_group: ActionControlGroup + select_army: ActionSelectArmy + select_warp_gates: ActionSelectWarpGates + select_larva: ActionSelectLarva + select_idle_worker: ActionSelectIdleWorker + multi_panel: ActionMultiPanel + cargo_panel: ActionCargoPanelUnload + production_panel: ActionProductionPanelRemoveFromQueue + toggle_autocast: ActionToggleAutocast + def __init__( + self, + control_group: ActionControlGroup = ..., + select_army: ActionSelectArmy = ..., + select_warp_gates: ActionSelectWarpGates = ..., + select_larva: ActionSelectLarva = ..., + select_idle_worker: ActionSelectIdleWorker = ..., + multi_panel: ActionMultiPanel = ..., + cargo_panel: ActionCargoPanelUnload = ..., + production_panel: ActionProductionPanelRemoveFromQueue = ..., + toggle_autocast: ActionToggleAutocast = ..., + ) -> None: ... + +class ControlGroupAction(Enum): + Recall: int + Set: int + Append: int + SetAndSteal: int + AppendAndSteal: int + +class ActionControlGroup(Message): + action: ControlGroupAction + control_group_index: int + def __init__(self, action: ControlGroupAction = ..., control_group_index: int = ...) -> None: ... + +class ActionSelectArmy(Message): + selection_add: bool + def __init__(self, selection_add: bool = ...) -> None: ... + +class ActionSelectWarpGates(Message): + selection_add: bool + def __init__(self, selection_add: bool = ...) -> None: ... + +class ActionSelectLarva(Message): + def __init__(self) -> None: ... + +class ActionSelectIdleWorker(Message): + class Type(Enum): + Set: int + Add: int + All: int + AddAll: int + + type: Type + def __init__(self, type: Type = ...) -> None: ... + +class ActionMultiPanel(Message): + class Type(Enum): + SingleSelect: int + DeselectUnit: int + SelectAllOfType: int + DeselectAllOfType: int + + type: Type + unit_index: int + def __init__(self, type: Type = ..., unit_index: int = ...) -> None: ... + +class ActionCargoPanelUnload(Message): + unit_index: int + def __init__(self, unit_index: int = ...) -> None: ... + +class ActionProductionPanelRemoveFromQueue(Message): + unit_index: int + def __init__(self, unit_index: int = ...) -> None: ... + +class ActionToggleAutocast(Message): + ability_id: int + def __init__(self, ability_id: int = ...) -> None: ... From ee5b33151efd89ae380e429bf2a412c49fb13abd Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 03:37:51 +0100 Subject: [PATCH 11/34] Finalize sc2api_pb2 --- s2clientprotocol/sc2api_pb2.pyi | 747 ++++++++++++++++++++++++++++++++ 1 file changed, 747 insertions(+) create mode 100644 s2clientprotocol/sc2api_pb2.pyi diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi new file mode 100644 index 00000000..10b271ef --- /dev/null +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -0,0 +1,747 @@ +from __future__ import annotations + +from enum import Enum +from google.protobuf.message import Message + +from s2clientprotocol.spatial_pb2 import ActionSpatial, ObservationFeatureLayer, ObservationRender +from .common_pb2 import Race, Point2D, Size2DI, AvailableAbility +from .data_pb2 import AbilityData, UnitTypeData, UpgradeData, BuffData, EffectData +from .debug_pb2 import DebugCommand +from .error_pb2 import ActionResult +from .query_pb2 import RequestQuery, ResponseQuery +from .raw_pb2 import ActionRaw, ObservationRaw, StartRaw +from .ui_pb2 import ActionUI, ObservationUI + +from .score_pb2 import Score + +class Request(Message): + create_game: RequestCreateGame + join_game: RequestJoinGame + restart_game: RequestRestartGame + start_replay: RequestStartReplay + leave_game: RequestLeaveGame + quick_save: RequestQuickSave + quick_load: RequestQuickLoad + quit: RequestQuit + game_info: RequestGameInfo + observation: RequestObservation + action: RequestAction + obs_action: RequestObserverAction + step: RequestStep + data: RequestData + query: RequestQuery + save_replay: RequestSaveReplay + map_command: RequestMapCommand + replay_info: RequestReplayInfo + available_maps: RequestAvailableMaps + save_map: RequestSaveMap + ping: RequestPing + debug: RequestDebug + id: int + def __init__( + self, + create_game: RequestCreateGame = ..., + join_game: RequestJoinGame = ..., + restart_game: RequestRestartGame = ..., + start_replay: RequestStartReplay = ..., + leave_game: RequestLeaveGame = ..., + quick_save: RequestQuickSave = ..., + quick_load: RequestQuickLoad = ..., + quit: RequestQuit = ..., + game_info: RequestGameInfo = ..., + observation: RequestObservation = ..., + action: RequestAction = ..., + obs_action: RequestObserverAction = ..., + step: RequestStep = ..., + data: RequestData = ..., + query: RequestQuery = ..., + save_replay: RequestSaveReplay = ..., + map_command: RequestMapCommand = ..., + replay_info: RequestReplayInfo = ..., + available_maps: RequestAvailableMaps = ..., + save_map: RequestSaveMap = ..., + ping: RequestPing = ..., + debug: RequestDebug = ..., + id: int = ..., + ) -> None: ... + +class Response(Message): + create_game: ResponseCreateGame + join_game: ResponseJoinGame + restart_game: ResponseRestartGame + start_replay: ResponseStartReplay + leave_game: ResponseLeaveGame + quick_save: ResponseQuickSave + quick_load: ResponseQuickLoad + quit: ResponseQuit + game_info: ResponseGameInfo + observation: ResponseObservation + action: ResponseAction + obs_action: ResponseObserverAction + step: ResponseStep + data: ResponseData + query: ResponseQuery + save_replay: ResponseSaveReplay + replay_info: ResponseReplayInfo + available_maps: ResponseAvailableMaps + save_map: ResponseSaveMap + map_command: ResponseMapCommand + ping: ResponsePing + debug: ResponseDebug + id: int + error: list[str] + status: Status + def __init__( + self, + create_game: ResponseCreateGame = ..., + join_game: ResponseJoinGame = ..., + restart_game: ResponseRestartGame = ..., + start_replay: ResponseStartReplay = ..., + leave_game: ResponseLeaveGame = ..., + quick_save: ResponseQuickSave = ..., + quick_load: ResponseQuickLoad = ..., + quit: ResponseQuit = ..., + game_info: ResponseGameInfo = ..., + observation: ResponseObservation = ..., + action: ResponseAction = ..., + obs_action: ResponseObserverAction = ..., + step: ResponseStep = ..., + data: ResponseData = ..., + query: ResponseQuery = ..., + save_replay: ResponseSaveReplay = ..., + replay_info: ResponseReplayInfo = ..., + available_maps: ResponseAvailableMaps = ..., + save_map: ResponseSaveMap = ..., + map_command: ResponseMapCommand = ..., + ping: ResponsePing = ..., + debug: ResponseDebug = ..., + id: int = ..., + error: list[str] = ..., + status: Status = ..., + ) -> None: ... + +class Status(Enum): + launched: int + init_game: int + in_game: int + in_replay: int + ended: int + quit: int + unknown: int + +class RequestCreateGame(Message): + local_map: LocalMap + battlenet_map_name: str + player_setup: list[PlayerSetup] + disable_fog: bool + random_seed: int + realtime: bool + def __init__( + self, + local_map: LocalMap = ..., + battlenet_map_name: str = ..., + player_setup: list[PlayerSetup] = ..., + disable_fog: bool = ..., + random_seed: int = ..., + realtime: bool = ..., + ) -> None: ... + +class LocalMap(Message): + map_path: str + map_data: bytes + def __init__(self, map_path: str = ..., map_data: bytes = ...) -> None: ... + +class ResponseCreateGame(Message): + class Error(Enum): + MissingMap: int + InvalidMapPath: int + InvalidMapData: int + InvalidMapName: int + InvalidMapHandle: int + MissingPlayerSetup: int + InvalidPlayerSetup: int + MultiplayerUnsupported: int + + error: Error + error_details: str + def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + +class RequestJoinGame(Message): + race: Race + observed_player_id: int + options: InterfaceOptions + server_ports: PortSet + client_ports: list[PortSet] + shared_port: int + player_name: str + host_ip: str + def __init__( + self, + race: Race = ..., + observed_player_id: int = ..., + options: InterfaceOptions = ..., + server_ports: PortSet = ..., + client_ports: list[PortSet] = ..., + shared_port: int = ..., + player_name: str = ..., + host_ip: str = ..., + ) -> None: ... + +class PortSet(Message): + game_port: int + base_port: int + def __init__(self, game_port: int = ..., base_port: int = ...) -> None: ... + +class ResponseJoinGame(Message): + class Error(Enum): + MissingParticipation: int + InvalidObservedPlayerId: int + MissingOptions: int + MissingPorts: int + GameFull: int + LaunchError: int + FeatureUnsupported: int + NoSpaceForUser: int + MapDoesNotExist: int + CannotOpenMap: int + ChecksumError: int + NetworkError: int + OtherError: int + + player_id: int + error: Error + error_details: str + def __init__(self, player_id: int = ..., error: Error = ..., error_details: str = ...) -> None: ... + +class RequestRestartGame(Message): + def __init__(self) -> None: ... + +class ResponseRestartGame(Message): + class Error(Enum): + LaunchError: int + + error: Error + error_details: str + need_hard_reset: bool + def __init__(self, error: Error = ..., error_details: str = ..., need_hard_reset: bool = ...) -> None: ... + +class RequestStartReplay(Message): + replay_path: str + replay_data: bytes + map_data: bytes + observed_player_id: int + options: InterfaceOptions + disable_fog: bool + realtime: bool + record_replay: bool + def __init__( + self, + replay_path: str = ..., + replay_data: bytes = ..., + map_data: bytes = ..., + observed_player_id: int = ..., + options: InterfaceOptions = ..., + disable_fog: bool = ..., + realtime: bool = ..., + record_replay: bool = ..., + ) -> None: ... + +class ResponseStartReplay(Message): + class Error(Enum): + MissingReplay: int + InvalidReplayPath: int + InvalidReplayData: int + InvalidMapData: int + InvalidObservedPlayerId: int + MissingOptions: int + LaunchError: int + + error: Error + error_details: str + def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + +class RequestMapCommand(Message): + trigger_cmd: str + def __init__(self, trigger_cmd: str = ...) -> None: ... + +class ResponseMapCommand(Message): + class Error(Enum): + NoTriggerError: int + + error: Error + error_details: str + def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + +class RequestLeaveGame(Message): + def __init__(self) -> None: ... + +class ResponseLeaveGame(Message): + def __init__(self) -> None: ... + +class RequestQuickSave(Message): + def __init__(self) -> None: ... + +class ResponseQuickSave(Message): + def __init__(self) -> None: ... + +class RequestQuickLoad(Message): + def __init__(self) -> None: ... + +class ResponseQuickLoad(Message): + def __init__(self) -> None: ... + +class RequestQuit(Message): + def __init__(self) -> None: ... + +class ResponseQuit(Message): + def __init__(self) -> None: ... + +class RequestGameInfo(Message): + def __init__(self) -> None: ... + +class ResponseGameInfo(Message): + map_name: str + mod_names: list[str] + local_map_path: str + player_info: list[PlayerInfo] + start_raw: StartRaw + options: InterfaceOptions + def __init__( + self, + map_name: str = ..., + mod_names: list[str] = ..., + local_map_path: str = ..., + player_info: list[PlayerInfo] = ..., + start_raw: StartRaw = ..., + options: InterfaceOptions = ..., + ) -> None: ... + +class RequestObservation(Message): + disable_fog: bool + game_loop: int + def __init__(self, disable_fog: bool = ..., game_loop: int = ...) -> None: ... + +class ResponseObservation(Message): + actions: list[Action] + action_errors: list[ActionError] + observation: Observation + player_result: list[PlayerResult] + chat: list[ChatReceived] + def __init__( + self, + actions: list[Action] = ..., + action_errors: list[ActionError] = ..., + observation: Observation = ..., + player_result: list[PlayerResult] = ..., + chat: list[ChatReceived] = ..., + ) -> None: ... + +class ChatReceived(Message): + player_id: int + message: str + def __init__(self, player_id: int = ..., message: str = ...) -> None: ... + +class RequestAction(Message): + actions: list[Action] + def __init__(self, actions: list[Action] = ...) -> None: ... + +class ResponseAction(Message): + result: list[ActionResult] + def __init__(self, result: list[ActionResult] = ...) -> None: ... + +class RequestObserverAction(Message): + actions: list[ObserverAction] + def __init__(self, actions: list[ObserverAction] = ...) -> None: ... + +class ResponseObserverAction(Message): + def __init__(self) -> None: ... + +class RequestStep(Message): + count: int + def __init__(self, count: int = ...) -> None: ... + +class ResponseStep(Message): + simulation_loop: int + def __init__(self, simulation_loop: int = ...) -> None: ... + +class RequestData(Message): + ability_id: bool + unit_type_id: bool + upgrade_id: bool + buff_id: bool + effect_id: bool + def __init__( + self, + ability_id: bool = ..., + unit_type_id: bool = ..., + upgrade_id: bool = ..., + buff_id: bool = ..., + effect_id: bool = ..., + ) -> None: ... + +class ResponseData(Message): + abilities: list[AbilityData] + units: list[UnitTypeData] + upgrades: list[UpgradeData] + buffs: list[BuffData] + effects: list[EffectData] + def __init__( + self, + abilities: list[AbilityData] = ..., + units: list[UnitTypeData] = ..., + upgrades: list[UpgradeData] = ..., + buffs: list[BuffData] = ..., + effects: list[EffectData] = ..., + ) -> None: ... + +class RequestSaveReplay(Message): + def __init__(self) -> None: ... + +class ResponseSaveReplay(Message): + data: bytes + def __init__(self, data: bytes = ...) -> None: ... + +class RequestReplayInfo(Message): + replay_path: str + replay_data: bytes + download_data: bool + def __init__( + self, + replay_path: str = ..., + replay_data: bytes = ..., + download_data: bool = ..., + ) -> None: ... + +class PlayerInfoExtra(Message): + player_info: PlayerInfo + player_result: PlayerResult + player_mmr: int + player_apm: int + def __init__( + self, + player_info: PlayerInfo = ..., + player_result: PlayerResult = ..., + player_mmr: int = ..., + player_apm: int = ..., + ) -> None: ... + +class ResponseReplayInfo(Message): + class Error(Enum): + MissingReplay: int + InvalidReplayPath: int + InvalidReplayData: int + ParsingError: int + DownloadError: int + + map_name: str + local_map_path: str + player_info: list[PlayerInfoExtra] + game_duration_loops: int + game_duration_seconds: float + game_version: str + data_version: str + data_build: int + base_build: int + error: Error + error_details: str + def __init__( + self, + map_name: str = ..., + local_map_path: str = ..., + player_info: list[PlayerInfoExtra] = ..., + game_duration_loops: int = ..., + game_duration_seconds: float = ..., + game_version: str = ..., + data_version: str = ..., + data_build: int = ..., + base_build: int = ..., + error: Error = ..., + error_details: str = ..., + ) -> None: ... + +class RequestAvailableMaps(Message): + def __init__(self) -> None: ... + +class ResponseAvailableMaps(Message): + local_map_paths: list[str] + battlenet_map_names: list[str] + def __init__(self, local_map_paths: list[str] = ..., battlenet_map_names: list[str] = ...) -> None: ... + +class RequestSaveMap(Message): + map_path: str + map_data: bytes + def __init__(self, map_path: str = ..., map_data: bytes = ...) -> None: ... + +class ResponseSaveMap(Message): + class Error(Enum): + InvalidMapData: int + + error: Error + def __init__(self, error: Error = ...) -> None: ... + +class RequestPing(Message): + def __init__(self) -> None: ... + +class ResponsePing(Message): + game_version: str + data_version: str + data_build: int + base_build: int + def __init__( + self, + game_version: str = ..., + data_version: str = ..., + data_build: int = ..., + base_build: int = ..., + ) -> None: ... + +class RequestDebug(Message): + debug: list[DebugCommand] + def __init__(self, debug: list[DebugCommand] = ...) -> None: ... + +class ResponseDebug(Message): + def __init__(self) -> None: ... + +class Difficulty(Enum): + VeryEasy: int + Easy: int + Medium: int + MediumHard: int + Hard: int + Harder: int + VeryHard: int + CheatVision: int + CheatMoney: int + CheatInsane: int + +class PlayerType(Enum): + Participant: int + Computer: int + Observer: int + +class AIBuild(Enum): + RandomBuild: int + Rush: int + Timing: int + Power: int + Macro: int + Air: int + +class PlayerSetup(Message): + type: PlayerType + race: Race + difficulty: Difficulty + player_name: str + ai_build: AIBuild + def __init__( + self, + type: PlayerType = ..., + race: Race = ..., + difficulty: Difficulty = ..., + player_name: str = ..., + ai_build: AIBuild = ..., + ) -> None: ... + +class SpatialCameraSetup(Message): + resolution: Size2DI + minimap_resolution: Size2DI + width: float + crop_to_playable_area: bool + allow_cheating_layers: bool + def __init__( + self, + resolution: Size2DI = ..., + minimap_resolution: Size2DI = ..., + width: float = ..., + crop_to_playable_area: bool = ..., + allow_cheating_layers: bool = ..., + ) -> None: ... + +class InterfaceOptions(Message): + raw: bool + score: bool + feature_layer: SpatialCameraSetup + render: SpatialCameraSetup + show_cloaked: bool + show_burrowed_shadows: bool + show_placeholders: bool + raw_affects_selection: bool + raw_crop_to_playable_area: bool + def __init__( + self, + raw: bool = ..., + score: bool = ..., + feature_layer: SpatialCameraSetup = ..., + render: SpatialCameraSetup = ..., + show_cloaked: bool = ..., + show_burrowed_shadows: bool = ..., + show_placeholders: bool = ..., + raw_affects_selection: bool = ..., + raw_crop_to_playable_area: bool = ..., + ) -> None: ... + +class PlayerInfo(Message): + player_id: int + type: PlayerType + race_requested: Race + race_actual: Race + difficulty: Difficulty + ai_build: AIBuild + player_name: str + def __init__( + self, + player_id: int = ..., + type: PlayerType = ..., + race_requested: Race = ..., + race_actual: Race = ..., + difficulty: Difficulty = ..., + ai_build: AIBuild = ..., + player_name: str = ..., + ) -> None: ... + +class PlayerCommon(Message): + player_id: int + minerals: int + vespene: int + food_cap: int + food_used: int + food_army: int + food_workers: int + idle_worker_count: int + army_count: int + warp_gate_count: int + larva_count: int + def __init__( + self, + player_id: int = ..., + minerals: int = ..., + vespene: int = ..., + food_cap: int = ..., + food_used: int = ..., + food_army: int = ..., + food_workers: int = ..., + idle_worker_count: int = ..., + army_count: int = ..., + warp_gate_count: int = ..., + larva_count: int = ..., + ) -> None: ... + +class Observation(Message): + game_loop: int + player_common: PlayerCommon + alerts: list[Alert] + abilities: list[AvailableAbility] + score: Score + raw_data: ObservationRaw + feature_layer_data: ObservationFeatureLayer + render_data: ObservationRender + ui_data: ObservationUI + def __init__( + self, + game_loop: int = ..., + player_common: PlayerCommon = ..., + alerts: list[Alert] = ..., + abilities: list[AvailableAbility] = ..., + score: Score = ..., + raw_data: ObservationRaw = ..., + feature_layer_data: ObservationFeatureLayer = ..., + render_data: ObservationRender = ..., + ui_data: ObservationUI = ..., + ) -> None: ... + +class Action(Message): + action_raw: ActionRaw + action_feature_layer: ActionSpatial + action_render: ActionSpatial + action_ui: ActionUI + action_chat: ActionChat + game_loop: int + def __init__( + self, + action_raw: ActionRaw = ..., + action_feature_layer: ActionSpatial = ..., + action_render: ActionSpatial = ..., + action_ui: ActionUI = ..., + action_chat: ActionChat = ..., + game_loop: int = ..., + ) -> None: ... + +class Channel(Enum): + Broadcast: int + Team: int + +class ActionChat(Message): + channel: Channel + message: str + def __init__(self, channel: Channel = ..., message: str = ...) -> None: ... + +class ActionError(Message): + unit_tag: int + ability_id: int + result: ActionResult + def __init__(self, unit_tag: int = ..., ability_id: int = ..., result: ActionResult = ...) -> None: ... + +class ObserverAction(Message): + player_perspective: ActionObserverPlayerPerspective + camera_move: ActionObserverCameraMove + camera_follow_player: ActionObserverCameraFollowPlayer + camera_follow_units: ActionObserverCameraFollowUnits + def __init__( + self, + player_perspective: ActionObserverPlayerPerspective = ..., + camera_move: ActionObserverCameraMove = ..., + camera_follow_player: ActionObserverCameraFollowPlayer = ..., + camera_follow_units: ActionObserverCameraFollowUnits = ..., + ) -> None: ... + +class ActionObserverPlayerPerspective(Message): + player_id: int + def __init__(self, player_id: int = ...) -> None: ... + +class ActionObserverCameraMove(Message): + world_pos: Point2D + distance: float + def __init__(self, world_pos: Point2D = ..., distance: float = ...) -> None: ... + +class ActionObserverCameraFollowPlayer(Message): + player_id: int + def __init__(self, player_id: int = ...) -> None: ... + +class ActionObserverCameraFollowUnits(Message): + unit_tags: list[int] + def __init__(self, unit_tags: list[int] = ...) -> None: ... + +class Alert(Enum): + AlertError: int + AddOnComplete: int + BuildingComplete: int + BuildingUnderAttack: int + LarvaHatched: int + MergeComplete: int + MineralsExhausted: int + MorphComplete: int + MothershipComplete: int + MULEExpired: int + NuclearLaunchDetected: int + NukeComplete: int + NydusWormDetected: int + ResearchComplete: int + TrainError: int + TrainUnitComplete: int + TrainWorkerComplete: int + TransformationComplete: int + UnitUnderAttack: int + UpgradeComplete: int + VespeneExhausted: int + WarpInComplete: int + +class Result(Enum): + Victory: int + Defeat: int + Tie: int + Undecided: int + +class PlayerResult(Message): + player_id: int + result: Result + def __init__(self, player_id: int = ..., result: Result = ...) -> None: ... From f13299d9107dc176c7dddfbfd6392987d5b28fb1 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 03:39:15 +0100 Subject: [PATCH 12/34] Finalize spatial_pb2 --- s2clientprotocol/spatial_pb2.pyi | 151 +++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 s2clientprotocol/spatial_pb2.pyi diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi new file mode 100644 index 00000000..a785470e --- /dev/null +++ b/s2clientprotocol/spatial_pb2.pyi @@ -0,0 +1,151 @@ +from __future__ import annotations + +from enum import Enum +from google.protobuf.message import Message +from .common_pb2 import ImageData, PointI, RectangleI + +class ObservationFeatureLayer(Message): + renders: FeatureLayers + minimap_renders: FeatureLayersMinimap + def __init__( + self, + renders: FeatureLayers = ..., + minimap_renders: FeatureLayersMinimap = ..., + ) -> None: ... + +class FeatureLayers(Message): + height_map: ImageData + visibility_map: ImageData + creep: ImageData + power: ImageData + player_id: ImageData + unit_type: ImageData + selected: ImageData + unit_hit_points: ImageData + unit_hit_points_ratio: ImageData + unit_energy: ImageData + unit_energy_ratio: ImageData + unit_shields: ImageData + unit_shields_ratio: ImageData + player_relative: ImageData + unit_density_aa: ImageData + unit_density: ImageData + effects: ImageData + hallucinations: ImageData + cloaked: ImageData + blip: ImageData + buffs: ImageData + buff_duration: ImageData + active: ImageData + build_progress: ImageData + buildable: ImageData + pathable: ImageData + placeholder: ImageData + def __init__( + self, + height_map: ImageData = ..., + visibility_map: ImageData = ..., + creep: ImageData = ..., + power: ImageData = ..., + player_id: ImageData = ..., + unit_type: ImageData = ..., + selected: ImageData = ..., + unit_hit_points: ImageData = ..., + unit_hit_points_ratio: ImageData = ..., + unit_energy: ImageData = ..., + unit_energy_ratio: ImageData = ..., + unit_shields: ImageData = ..., + unit_shields_ratio: ImageData = ..., + player_relative: ImageData = ..., + unit_density_aa: ImageData = ..., + unit_density: ImageData = ..., + effects: ImageData = ..., + hallucinations: ImageData = ..., + cloaked: ImageData = ..., + blip: ImageData = ..., + buffs: ImageData = ..., + buff_duration: ImageData = ..., + active: ImageData = ..., + build_progress: ImageData = ..., + buildable: ImageData = ..., + pathable: ImageData = ..., + placeholder: ImageData = ..., + ) -> None: ... + +class FeatureLayersMinimap(Message): + height_map: ImageData + visibility_map: ImageData + creep: ImageData + camera: ImageData + player_id: ImageData + player_relative: ImageData + selected: ImageData + alerts: ImageData + buildable: ImageData + pathable: ImageData + unit_type: ImageData + def __init__( + self, + height_map: ImageData = ..., + visibility_map: ImageData = ..., + creep: ImageData = ..., + camera: ImageData = ..., + player_id: ImageData = ..., + player_relative: ImageData = ..., + selected: ImageData = ..., + alerts: ImageData = ..., + buildable: ImageData = ..., + pathable: ImageData = ..., + unit_type: ImageData = ..., + ) -> None: ... + +class ObservationRender(Message): + map: ImageData + minimap: ImageData + def __init__(self, map: ImageData = ..., minimap: ImageData = ...) -> None: ... + +class ActionSpatial(Message): + unit_command: ActionSpatialUnitCommand + camera_move: ActionSpatialCameraMove + unit_selection_point: ActionSpatialUnitSelectionPoint + unit_selection_rect: ActionSpatialUnitSelectionRect + def __init__( + self, + unit_command: ActionSpatialUnitCommand = ..., + camera_move: ActionSpatialCameraMove = ..., + unit_selection_point: ActionSpatialUnitSelectionPoint = ..., + unit_selection_rect: ActionSpatialUnitSelectionRect = ..., + ) -> None: ... + +class ActionSpatialUnitCommand(Message): + ability_id: int + target_screen_coord: PointI + target_minimap_coord: PointI + queue_command: bool + def __init__( + self, + ability_id: int = ..., + target_screen_coord: PointI = ..., + target_minimap_coord: PointI = ..., + queue_command: bool = ..., + ) -> None: ... + +class ActionSpatialCameraMove(Message): + center_minimap: PointI + def __init__(self, center_minimap: PointI = ...) -> None: ... + +class Type(Enum): + Select: int + Toggle: int + AllType: int + AddAllType: int + +class ActionSpatialUnitSelectionPoint(Message): + selection_screen_coord: PointI + type: Type + def __init__(self, selection_screen_coord: PointI = ..., type: Type = ...) -> None: ... + +class ActionSpatialUnitSelectionRect(Message): + selection_screen_coord: list[RectangleI] + selection_add: bool + def __init__(self, selection_screen_coord: list[RectangleI] = ..., selection_add: bool = ...) -> None: ... From d442bac9ce7f5f185e4e9a46156112a3ed0b5ff1 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 14:09:49 +0100 Subject: [PATCH 13/34] Adjust enums to be of type int --- s2clientprotocol/data_pb2.pyi | 17 +++--- s2clientprotocol/debug_pb2.pyi | 16 +++--- s2clientprotocol/query_pb2.pyi | 5 +- s2clientprotocol/raw_pb2.pyi | 16 +++--- s2clientprotocol/sc2api_pb2.pyi | 95 ++++++++++++++++---------------- s2clientprotocol/score_pb2.pyi | 4 +- s2clientprotocol/spatial_pb2.pyi | 4 +- s2clientprotocol/ui_pb2.pyi | 12 ++-- 8 files changed, 83 insertions(+), 86 deletions(-) diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index 2ba8f190..42d9410a 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -1,6 +1,5 @@ from enum import Enum from google.protobuf.message import Message -from .common_pb2 import Race class Target(Enum): # NONE: int @@ -18,7 +17,7 @@ class AbilityData(Message): hotkey: str remaps_to_ability_id: int available: bool - target: Target + target: int allow_minimap: bool allow_autocast: bool is_building: bool @@ -35,7 +34,7 @@ class AbilityData(Message): hotkey: str = ..., remaps_to_ability_id: int = ..., available: bool = ..., - target: Target = ..., + target: int = ..., allow_minimap: bool = ..., allow_autocast: bool = ..., is_building: bool = ..., @@ -58,9 +57,9 @@ class Attribute(Enum): Summoned: int class DamageBonus(Message): - attribute: Attribute + attribute: int bonus: float - def __init__(self, attribute: Attribute = ..., bonus: float = ...) -> None: ... + def __init__(self, attribute: int = ..., bonus: float = ...) -> None: ... class TargetType(Enum): Ground: int @@ -68,7 +67,7 @@ class TargetType(Enum): Any: int class Weapon(Message): - type: TargetType + type: int damage: float damage_bonus: list[DamageBonus] attacks: int @@ -76,7 +75,7 @@ class Weapon(Message): speed: float def __init__( self, - type: TargetType = ..., + type: int = ..., damage: float = ..., damage_bonus: list[DamageBonus] = ..., attacks: int = ..., @@ -94,7 +93,7 @@ class UnitTypeData(Message): food_required: float food_provided: float ability_id: int - race: Race + race: int build_time: float has_vespene: bool has_minerals: bool @@ -118,7 +117,7 @@ class UnitTypeData(Message): food_required: float = ..., food_provided: float = ..., ability_id: int = ..., - race: Race = ..., + race: int = ..., build_time: float = ..., has_vespene: bool = ..., has_minerals: bool = ..., diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi index ea6617fe..ef8de239 100644 --- a/s2clientprotocol/debug_pb2.pyi +++ b/s2clientprotocol/debug_pb2.pyi @@ -4,7 +4,7 @@ from .common_pb2 import Point, Point2D class DebugCommand(Message): draw: DebugDraw - game_state: DebugGameState + game_state: int create_unit: DebugCreateUnit kill_unit: DebugKillUnit test_process: DebugTestProcess @@ -14,7 +14,7 @@ class DebugCommand(Message): def __init__( self, draw: DebugDraw = ..., - game_state: DebugGameState = ..., + game_state: int = ..., create_unit: DebugCreateUnit = ..., kill_unit: DebugKillUnit = ..., test_process: DebugTestProcess = ..., @@ -116,9 +116,9 @@ class Test(Enum): exit: int class DebugTestProcess(Message): - test: Test + test: int delay_ms: int - def __init__(self, test: Test = ..., delay_ms: int = ...) -> None: ... + def __init__(self, test: int = ..., delay_ms: int = ...) -> None: ... class DebugSetScore(Message): score: float @@ -129,8 +129,8 @@ class EndResult(Enum): DeclareVictory: int class DebugEndGame(Message): - end_result: EndResult - def __init__(self, end_result: EndResult = ...) -> None: ... + end_result: int + def __init__(self, end_result: int = ...) -> None: ... class UnitValue(Enum): Energy: int @@ -138,12 +138,12 @@ class UnitValue(Enum): Shields: int class DebugSetUnitValue(Message): - unit_value: UnitValue + unit_value: int value: float unit_tag: int def __init__( self, - unit_value: UnitValue = ..., + unit_value: int = ..., value: float = ..., unit_tag: int = ..., ) -> None: ... diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index dfd7823b..047fcd46 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,6 +1,5 @@ from google.protobuf.message import Message from .common_pb2 import Point2D, AvailableAbility -from .error_pb2 import ActionResult class RequestQuery(Message): pathing: list[RequestQueryPathing] @@ -68,5 +67,5 @@ class RequestQueryBuildingPlacement(Message): ) -> None: ... class ResponseQueryBuildingPlacement(Message): - result: ActionResult - def __init__(self, result: ActionResult = ...) -> None: ... + result: int + def __init__(self, result: int = ...) -> None: ... diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi index ac4a7090..0a31096f 100644 --- a/s2clientprotocol/raw_pb2.pyi +++ b/s2clientprotocol/raw_pb2.pyi @@ -117,8 +117,8 @@ class RallyTarget(Message): def __init__(self, point: Point = ..., tag: int = ...) -> None: ... class Unit(Message): - display_type: DisplayType - alliance: Alliance + display_type: int + alliance: int tag: int unit_type: int owner: int @@ -126,7 +126,7 @@ class Unit(Message): facing: float radius: float build_progress: float - cloak: CloakState + cloak: int buff_ids: list[int] detect_range: float radar_range: float @@ -163,8 +163,8 @@ class Unit(Message): rally_targets: list[RallyTarget] def __init__( self, - display_type: DisplayType = ..., - alliance: Alliance = ..., + display_type: int = ..., + alliance: int = ..., tag: int = ..., unit_type: int = ..., owner: int = ..., @@ -172,7 +172,7 @@ class Unit(Message): facing: float = ..., radius: float = ..., build_progress: float = ..., - cloak: CloakState = ..., + cloak: int = ..., buff_ids: list[int] = ..., detect_range: float = ..., radar_range: float = ..., @@ -221,14 +221,14 @@ class Event(Message): class Effect(Message): effect_id: int pos: list[Point2D] - alliance: Alliance + alliance: int owner: int radius: float def __init__( self, effect_id: int = ..., pos: list[Point2D] = ..., - alliance: Alliance = ..., + alliance: int = ..., owner: int = ..., radius: float = ..., ) -> None: ... diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi index 10b271ef..d1f7a552 100644 --- a/s2clientprotocol/sc2api_pb2.pyi +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -4,10 +4,9 @@ from enum import Enum from google.protobuf.message import Message from s2clientprotocol.spatial_pb2 import ActionSpatial, ObservationFeatureLayer, ObservationRender -from .common_pb2 import Race, Point2D, Size2DI, AvailableAbility +from .common_pb2 import Point2D, Size2DI, AvailableAbility from .data_pb2 import AbilityData, UnitTypeData, UpgradeData, BuffData, EffectData from .debug_pb2 import DebugCommand -from .error_pb2 import ActionResult from .query_pb2 import RequestQuery, ResponseQuery from .raw_pb2 import ActionRaw, ObservationRaw, StartRaw from .ui_pb2 import ActionUI, ObservationUI @@ -90,7 +89,7 @@ class Response(Message): debug: ResponseDebug id: int error: list[str] - status: Status + status: int def __init__( self, create_game: ResponseCreateGame = ..., @@ -117,7 +116,7 @@ class Response(Message): debug: ResponseDebug = ..., id: int = ..., error: list[str] = ..., - status: Status = ..., + status: int = ..., ) -> None: ... class Status(Enum): @@ -162,12 +161,12 @@ class ResponseCreateGame(Message): InvalidPlayerSetup: int MultiplayerUnsupported: int - error: Error + error: int error_details: str - def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... class RequestJoinGame(Message): - race: Race + race: int observed_player_id: int options: InterfaceOptions server_ports: PortSet @@ -177,7 +176,7 @@ class RequestJoinGame(Message): host_ip: str def __init__( self, - race: Race = ..., + race: int = ..., observed_player_id: int = ..., options: InterfaceOptions = ..., server_ports: PortSet = ..., @@ -209,9 +208,9 @@ class ResponseJoinGame(Message): OtherError: int player_id: int - error: Error + error: int error_details: str - def __init__(self, player_id: int = ..., error: Error = ..., error_details: str = ...) -> None: ... + def __init__(self, player_id: int = ..., error: int = ..., error_details: str = ...) -> None: ... class RequestRestartGame(Message): def __init__(self) -> None: ... @@ -220,10 +219,10 @@ class ResponseRestartGame(Message): class Error(Enum): LaunchError: int - error: Error + error: int error_details: str need_hard_reset: bool - def __init__(self, error: Error = ..., error_details: str = ..., need_hard_reset: bool = ...) -> None: ... + def __init__(self, error: int = ..., error_details: str = ..., need_hard_reset: bool = ...) -> None: ... class RequestStartReplay(Message): replay_path: str @@ -256,9 +255,9 @@ class ResponseStartReplay(Message): MissingOptions: int LaunchError: int - error: Error + error: int error_details: str - def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... class RequestMapCommand(Message): trigger_cmd: str @@ -268,9 +267,9 @@ class ResponseMapCommand(Message): class Error(Enum): NoTriggerError: int - error: Error + error: int error_details: str - def __init__(self, error: Error = ..., error_details: str = ...) -> None: ... + def __init__(self, error: int = ..., error_details: str = ...) -> None: ... class RequestLeaveGame(Message): def __init__(self) -> None: ... @@ -346,8 +345,8 @@ class RequestAction(Message): def __init__(self, actions: list[Action] = ...) -> None: ... class ResponseAction(Message): - result: list[ActionResult] - def __init__(self, result: list[ActionResult] = ...) -> None: ... + result: list[int] + def __init__(self, result: list[int] = ...) -> None: ... class RequestObserverAction(Message): actions: list[ObserverAction] @@ -442,7 +441,7 @@ class ResponseReplayInfo(Message): data_version: str data_build: int base_build: int - error: Error + error: int error_details: str def __init__( self, @@ -455,7 +454,7 @@ class ResponseReplayInfo(Message): data_version: str = ..., data_build: int = ..., base_build: int = ..., - error: Error = ..., + error: int = ..., error_details: str = ..., ) -> None: ... @@ -476,8 +475,8 @@ class ResponseSaveMap(Message): class Error(Enum): InvalidMapData: int - error: Error - def __init__(self, error: Error = ...) -> None: ... + error: int + def __init__(self, error: int = ...) -> None: ... class RequestPing(Message): def __init__(self) -> None: ... @@ -528,18 +527,18 @@ class AIBuild(Enum): Air: int class PlayerSetup(Message): - type: PlayerType - race: Race - difficulty: Difficulty + type: int + race: int + difficulty: int player_name: str - ai_build: AIBuild + ai_build: int def __init__( self, - type: PlayerType = ..., - race: Race = ..., - difficulty: Difficulty = ..., + type: int = ..., + race: int = ..., + difficulty: int = ..., player_name: str = ..., - ai_build: AIBuild = ..., + ai_build: int = ..., ) -> None: ... class SpatialCameraSetup(Message): @@ -582,20 +581,20 @@ class InterfaceOptions(Message): class PlayerInfo(Message): player_id: int - type: PlayerType - race_requested: Race - race_actual: Race - difficulty: Difficulty - ai_build: AIBuild + type: int + race_requested: int + race_actual: int + difficulty: int + ai_build: int player_name: str def __init__( self, player_id: int = ..., - type: PlayerType = ..., - race_requested: Race = ..., - race_actual: Race = ..., - difficulty: Difficulty = ..., - ai_build: AIBuild = ..., + type: int = ..., + race_requested: int = ..., + race_actual: int = ..., + difficulty: int = ..., + ai_build: int = ..., player_name: str = ..., ) -> None: ... @@ -629,7 +628,7 @@ class PlayerCommon(Message): class Observation(Message): game_loop: int player_common: PlayerCommon - alerts: list[Alert] + alerts: list[int] abilities: list[AvailableAbility] score: Score raw_data: ObservationRaw @@ -640,7 +639,7 @@ class Observation(Message): self, game_loop: int = ..., player_common: PlayerCommon = ..., - alerts: list[Alert] = ..., + alerts: list[int] = ..., abilities: list[AvailableAbility] = ..., score: Score = ..., raw_data: ObservationRaw = ..., @@ -671,15 +670,15 @@ class Channel(Enum): Team: int class ActionChat(Message): - channel: Channel + channel: int message: str - def __init__(self, channel: Channel = ..., message: str = ...) -> None: ... + def __init__(self, channel: int = ..., message: str = ...) -> None: ... class ActionError(Message): unit_tag: int ability_id: int - result: ActionResult - def __init__(self, unit_tag: int = ..., ability_id: int = ..., result: ActionResult = ...) -> None: ... + result: int + def __init__(self, unit_tag: int = ..., ability_id: int = ..., result: int = ...) -> None: ... class ObserverAction(Message): player_perspective: ActionObserverPlayerPerspective @@ -743,5 +742,5 @@ class Result(Enum): class PlayerResult(Message): player_id: int - result: Result - def __init__(self, player_id: int = ..., result: Result = ...) -> None: ... + result: int + def __init__(self, player_id: int = ..., result: int = ...) -> None: ... diff --git a/s2clientprotocol/score_pb2.pyi b/s2clientprotocol/score_pb2.pyi index 61d54c01..01ea85e7 100644 --- a/s2clientprotocol/score_pb2.pyi +++ b/s2clientprotocol/score_pb2.pyi @@ -7,12 +7,12 @@ class ScoreType(Enum): Melee: int class Score(Message): - score_type: ScoreType + score_type: int score: int score_details: ScoreDetails def __init__( self, - score_type: ScoreType = ..., + score_type: int = ..., score: int = ..., score_details: ScoreDetails = ..., ) -> None: ... diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi index a785470e..dfdd5762 100644 --- a/s2clientprotocol/spatial_pb2.pyi +++ b/s2clientprotocol/spatial_pb2.pyi @@ -142,8 +142,8 @@ class Type(Enum): class ActionSpatialUnitSelectionPoint(Message): selection_screen_coord: PointI - type: Type - def __init__(self, selection_screen_coord: PointI = ..., type: Type = ...) -> None: ... + type: int + def __init__(self, selection_screen_coord: PointI = ..., type: int = ...) -> None: ... class ActionSpatialUnitSelectionRect(Message): selection_screen_coord: list[RectangleI] diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi index 4fb1507d..b4e1360d 100644 --- a/s2clientprotocol/ui_pb2.pyi +++ b/s2clientprotocol/ui_pb2.pyi @@ -132,9 +132,9 @@ class ControlGroupAction(Enum): AppendAndSteal: int class ActionControlGroup(Message): - action: ControlGroupAction + action: int control_group_index: int - def __init__(self, action: ControlGroupAction = ..., control_group_index: int = ...) -> None: ... + def __init__(self, action: int = ..., control_group_index: int = ...) -> None: ... class ActionSelectArmy(Message): selection_add: bool @@ -154,8 +154,8 @@ class ActionSelectIdleWorker(Message): All: int AddAll: int - type: Type - def __init__(self, type: Type = ...) -> None: ... + type: int + def __init__(self, type: int = ...) -> None: ... class ActionMultiPanel(Message): class Type(Enum): @@ -164,9 +164,9 @@ class ActionMultiPanel(Message): SelectAllOfType: int DeselectAllOfType: int - type: Type + type: int unit_index: int - def __init__(self, type: Type = ..., unit_index: int = ...) -> None: ... + def __init__(self, type: int = ..., unit_index: int = ...) -> None: ... class ActionCargoPanelUnload(Message): unit_index: int From eac3f9a7c0a6dcd181239ade7d0fb49c6750dcd5 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 14:11:16 +0100 Subject: [PATCH 14/34] Add type hints for python-sc2 related to protobuf types --- sc2/game_info.py | 3 ++- sc2/game_state.py | 21 +++++++++++++-------- sc2/player.py | 3 ++- sc2/position.py | 15 ++++++++++++--- sc2/power_source.py | 5 +++-- sc2/unit.py | 9 +++++---- sc2/units.py | 3 ++- 7 files changed, 39 insertions(+), 20 deletions(-) diff --git a/sc2/game_info.py b/sc2/game_info.py index c00a0428..4e6f1243 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -9,6 +9,7 @@ import numpy as np +from s2clientprotocol import sc2api_pb2 from sc2.pixel_map import PixelMap from sc2.player import Player from sc2.position import Point2, Rect, Size @@ -217,7 +218,7 @@ def protoss_wall_warpin(self) -> Point2 | None: class GameInfo: - def __init__(self, proto) -> None: + def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self._proto = proto self.players: list[Player] = [Player.from_proto(p) for p in self._proto.player_info] self.map_name: str = self._proto.map_name diff --git a/sc2/game_state.py b/sc2/game_state.py index b17fd12e..e4b7d17b 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -7,6 +7,7 @@ from loguru import logger +from s2clientprotocol import raw_pb2, sc2api_pb2 from sc2.constants import IS_ENEMY, IS_MINE, FakeEffectID, FakeEffectRadii from sc2.data import Alliance, DisplayType from sc2.ids.ability_id import AbilityId @@ -25,7 +26,7 @@ class Blip: - def __init__(self, proto) -> None: + def __init__(self, proto: raw_pb2.Unit) -> None: """ :param proto: """ @@ -91,7 +92,7 @@ def __getattr__(self, attr) -> int: class EffectData: - def __init__(self, proto, fake: bool = False) -> None: + def __init__(self, proto: raw_pb2.Effect | raw_pb2.Unit, fake: bool = False) -> None: """ :param proto: :param fake: @@ -101,20 +102,20 @@ def __init__(self, proto, fake: bool = False) -> None: @property def id(self) -> EffectId | str: - if self.fake: + if isinstance(self._proto, raw_pb2.Unit): # Returns the string from constants.py, e.g. "KD8CHARGE" return FakeEffectID[self._proto.unit_type] return EffectId(self._proto.effect_id) @property def positions(self) -> set[Point2]: - if self.fake: + if isinstance(self._proto, raw_pb2.Unit): return {Point2.from_proto(self._proto.pos)} return {Point2.from_proto(p) for p in self._proto.pos} @property def alliance(self) -> Alliance: - return self._proto.alliance + return Alliance(self._proto.alliance) @property def is_mine(self) -> bool: @@ -191,7 +192,11 @@ class ActionError(AbilityLookupTemplateClass): class GameState: - def __init__(self, response_observation, previous_observation=None) -> None: + def __init__( + self, + response_observation: sc2api_pb2.ResponseObservation, + previous_observation: sc2api_pb2.ResponseObservation | None = None, + ) -> None: """ :param response_observation: :param previous_observation: @@ -252,7 +257,7 @@ def alerts(self) -> list[int]: """ Game alerts, see https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/sc2api.proto#L683-L706 """ - if self.previous_observation: + if self.previous_observation is not None: return list(chain(self.previous_observation.observation.alerts, self.observation.alerts)) return self.observation.alerts @@ -265,7 +270,7 @@ def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | Actio Each action is converted into Python dataclasses: ActionRawUnitCommand, ActionRawToggleAutocast, ActionRawCameraMove """ previous_frame_actions = self.previous_observation.actions if self.previous_observation else [] - actions = [] + actions: list[ActionRawUnitCommand | ActionRawToggleAutocast | ActionRawCameraMove] = [] for action in chain(previous_frame_actions, self.response_observation.actions): action_raw = action.action_raw game_loop = action.game_loop diff --git a/sc2/player.py b/sc2/player.py index 74ee5463..935535c0 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -4,6 +4,7 @@ from abc import ABC from pathlib import Path +from s2clientprotocol import sc2api_pb2 from sc2.bot_ai import BotAI from sc2.data import AIBuild, Difficulty, PlayerType, Race @@ -107,7 +108,7 @@ def __init__( self.actual_race: Race = actual_race @classmethod - def from_proto(cls, proto) -> Player: + def from_proto(cls, proto: sc2api_pb2.PlayerInfo) -> Player: if PlayerType(proto.type) == PlayerType.Observer: return cls(proto.player_id, PlayerType(proto.type), None, None, None) return cls( diff --git a/sc2/position.py b/sc2/position.py index 3e37ad2f..7f7da4c4 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -142,7 +142,9 @@ def __hash__(self) -> int: class Point2(Pointlike): @classmethod - def from_proto(cls, data: common_pb.Point2D) -> Point2: + def from_proto( + cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI | Point2 | Point3 + ) -> Point2: """ :param data: """ @@ -324,7 +326,7 @@ def center(points: list[Point2]) -> Point2: class Point3(Point2): @classmethod - def from_proto(cls, data: common_pb.Point) -> Point3: + def from_proto(cls, data: common_pb.Point | Point3) -> Point3: """ :param data: """ @@ -355,6 +357,13 @@ def __add__(self, other: Point2 | Point3) -> Point3: class Size(Point2): + @classmethod + def from_proto(cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI) -> Size: + """ + :param data: + """ + return cls((data.x, data.y)) + @property def width(self) -> float: return self[0] @@ -366,7 +375,7 @@ def height(self) -> float: class Rect(tuple): @classmethod - def from_proto(cls, data) -> Rect: + def from_proto(cls, data: common_pb.RectangleI) -> Rect: """ :param data: """ diff --git a/sc2/power_source.py b/sc2/power_source.py index 8c64bb62..07814833 100644 --- a/sc2/power_source.py +++ b/sc2/power_source.py @@ -2,6 +2,7 @@ from dataclasses import dataclass +from s2clientprotocol import raw_pb2 from sc2.position import Point2 @@ -15,7 +16,7 @@ def __post_init__(self) -> None: assert self.radius > 0 @classmethod - def from_proto(cls, proto) -> PowerSource: + def from_proto(cls, proto: raw_pb2.PowerSource) -> PowerSource: return PowerSource(Point2.from_proto(proto.pos), proto.radius, proto.tag) def covers(self, position: Point2) -> bool: @@ -30,7 +31,7 @@ class PsionicMatrix: sources: list[PowerSource] @classmethod - def from_proto(cls, proto) -> PsionicMatrix: + def from_proto(cls, proto: list[raw_pb2.PowerSource]) -> PsionicMatrix: return PsionicMatrix([PowerSource.from_proto(p) for p in proto]) def covers(self, position: Point2) -> bool: diff --git a/sc2/unit.py b/sc2/unit.py index 07b63e90..05949505 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -7,6 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any +from s2clientprotocol import raw_pb2 from sc2.cache import CacheDict from sc2.constants import ( CAN_BE_ATTACKED, @@ -71,7 +72,7 @@ class RallyTarget: tag: int | None = None @classmethod - def from_proto(cls, proto: Any) -> RallyTarget: + def from_proto(cls, proto: raw_pb2.RallyTarget) -> RallyTarget: return cls( Point2.from_proto(proto.point), proto.tag if proto.HasField("tag") else None, @@ -85,7 +86,7 @@ class UnitOrder: progress: float = 0 @classmethod - def from_proto(cls, proto: Any, bot_object: BotAI) -> UnitOrder: + def from_proto(cls, proto: raw_pb2.UnitOrder, bot_object: BotAI) -> UnitOrder: target: int | Point2 | None = proto.target_unit_tag if proto.HasField("target_world_space_pos"): target = Point2.from_proto(proto.target_world_space_pos) @@ -106,7 +107,7 @@ class Unit: def __init__( self, - proto_data, + proto_data: raw_pb2.Unit, bot_object: BotAI, distance_calculation_index: int = -1, base_build: int = -1, @@ -1034,7 +1035,7 @@ def order_target(self) -> int | Point2 | None: from the first order, returns None if the unit is idle""" if self.orders: target = self.orders[0].target - if isinstance(target, int): + if target is None or isinstance(target, int): return target return Point2.from_proto(target) return None diff --git a/sc2/units.py b/sc2/units.py index 38813601..c41d8ed7 100644 --- a/sc2/units.py +++ b/sc2/units.py @@ -6,6 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any +from s2clientprotocol import raw_pb2 from sc2.ids.unit_typeid import UnitTypeId from sc2.position import Point2 from sc2.unit import Unit @@ -18,7 +19,7 @@ class Units(list): """A collection of Unit objects. Makes it easy to select units by selectors.""" @classmethod - def from_proto(cls, units, bot_object: BotAI) -> Units: + def from_proto(cls, units: list[raw_pb2.Unit], bot_object: BotAI) -> Units: return cls((Unit(raw_unit, bot_object=bot_object) for raw_unit in units), bot_object) def __init__(self, units: Iterable[Unit], bot_object: BotAI) -> None: From 4346c19413cb58779324de55c5819a3621a52780 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Sat, 15 Nov 2025 14:12:50 +0100 Subject: [PATCH 15/34] Apply autoformat --- s2clientprotocol/common_pb2.pyi | 1 + s2clientprotocol/data_pb2.pyi | 1 + s2clientprotocol/debug_pb2.pyi | 2 ++ s2clientprotocol/query_pb2.pyi | 3 ++- s2clientprotocol/raw_pb2.pyi | 4 +++- s2clientprotocol/sc2api_pb2.pyi | 9 +++++---- s2clientprotocol/score_pb2.pyi | 2 ++ s2clientprotocol/spatial_pb2.pyi | 2 ++ s2clientprotocol/ui_pb2.pyi | 2 ++ sc2/bot_ai_internal.py | 1 - sc2/client.py | 1 - sc2/data.py | 1 - test/generate_pickle_files_bot.py | 1 - 13 files changed, 20 insertions(+), 10 deletions(-) diff --git a/s2clientprotocol/common_pb2.pyi b/s2clientprotocol/common_pb2.pyi index 2a4028c7..e586cfef 100644 --- a/s2clientprotocol/common_pb2.pyi +++ b/s2clientprotocol/common_pb2.pyi @@ -1,5 +1,6 @@ # https://github.com/Blizzard/s2client-proto/blob/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol/common.proto from enum import Enum + from google.protobuf.message import Message class AvailableAbility(Message): diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index 42d9410a..aff16112 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -1,4 +1,5 @@ from enum import Enum + from google.protobuf.message import Message class Target(Enum): diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi index ef8de239..fe710158 100644 --- a/s2clientprotocol/debug_pb2.pyi +++ b/s2clientprotocol/debug_pb2.pyi @@ -1,5 +1,7 @@ from enum import Enum + from google.protobuf.message import Message + from .common_pb2 import Point, Point2D class DebugCommand(Message): diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index 047fcd46..2ba65f0a 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,5 +1,6 @@ from google.protobuf.message import Message -from .common_pb2 import Point2D, AvailableAbility + +from .common_pb2 import AvailableAbility, Point2D class RequestQuery(Message): pathing: list[RequestQueryPathing] diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi index 0a31096f..c41ab731 100644 --- a/s2clientprotocol/raw_pb2.pyi +++ b/s2clientprotocol/raw_pb2.pyi @@ -1,6 +1,8 @@ from enum import Enum + from google.protobuf.message import Message -from .common_pb2 import Point2D, Point, Size2DI, ImageData, RectangleI + +from .common_pb2 import ImageData, Point, Point2D, RectangleI, Size2DI class StartRaw(Message): map_size: Size2DI diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi index d1f7a552..23ad2d15 100644 --- a/s2clientprotocol/sc2api_pb2.pyi +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -1,17 +1,18 @@ from __future__ import annotations from enum import Enum + from google.protobuf.message import Message from s2clientprotocol.spatial_pb2 import ActionSpatial, ObservationFeatureLayer, ObservationRender -from .common_pb2 import Point2D, Size2DI, AvailableAbility -from .data_pb2 import AbilityData, UnitTypeData, UpgradeData, BuffData, EffectData + +from .common_pb2 import AvailableAbility, Point2D, Size2DI +from .data_pb2 import AbilityData, BuffData, EffectData, UnitTypeData, UpgradeData from .debug_pb2 import DebugCommand from .query_pb2 import RequestQuery, ResponseQuery from .raw_pb2 import ActionRaw, ObservationRaw, StartRaw -from .ui_pb2 import ActionUI, ObservationUI - from .score_pb2 import Score +from .ui_pb2 import ActionUI, ObservationUI class Request(Message): create_game: RequestCreateGame diff --git a/s2clientprotocol/score_pb2.pyi b/s2clientprotocol/score_pb2.pyi index 01ea85e7..88c47391 100644 --- a/s2clientprotocol/score_pb2.pyi +++ b/s2clientprotocol/score_pb2.pyi @@ -1,5 +1,7 @@ from __future__ import annotations + from enum import Enum + from google.protobuf.message import Message class ScoreType(Enum): diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi index dfdd5762..0fad99e6 100644 --- a/s2clientprotocol/spatial_pb2.pyi +++ b/s2clientprotocol/spatial_pb2.pyi @@ -1,7 +1,9 @@ from __future__ import annotations from enum import Enum + from google.protobuf.message import Message + from .common_pb2 import ImageData, PointI, RectangleI class ObservationFeatureLayer(Message): diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi index b4e1360d..236bb70c 100644 --- a/s2clientprotocol/ui_pb2.pyi +++ b/s2clientprotocol/ui_pb2.pyi @@ -1,5 +1,7 @@ from __future__ import annotations + from enum import Enum + from google.protobuf.message import Message class ObservationUI(Message): diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index bb5738c2..061ef70f 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -16,7 +16,6 @@ # pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.cache import property_cache_once_per_frame from sc2.constants import ( ALL_GAS, diff --git a/sc2/client.py b/sc2/client.py index 5888b190..2ad47ca2 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -12,7 +12,6 @@ from s2clientprotocol import raw_pb2 as raw_pb from s2clientprotocol import sc2api_pb2 as sc_pb from s2clientprotocol import spatial_pb2 as spatial_pb - from sc2.action import combine_actions from sc2.data import ActionResult, ChatChannel, Race, Result, Status from sc2.game_data import AbilityData, GameData diff --git a/sc2/data.py b/sc2/data.py index a4377b8a..ce271edc 100644 --- a/sc2/data.py +++ b/sc2/data.py @@ -14,7 +14,6 @@ from s2clientprotocol import error_pb2 as error_pb from s2clientprotocol import raw_pb2 as raw_pb from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId diff --git a/test/generate_pickle_files_bot.py b/test/generate_pickle_files_bot.py index ff12aa29..ae8d46b8 100644 --- a/test/generate_pickle_files_bot.py +++ b/test/generate_pickle_files_bot.py @@ -12,7 +12,6 @@ # pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2 import maps from sc2.bot_ai import BotAI from sc2.data import Difficulty, Race From f207373d72c41793a90e5e8706690b1f13eb0085 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Mon, 17 Nov 2025 16:59:33 +0100 Subject: [PATCH 16/34] Add type hints for various files --- s2clientprotocol/data_pb2.pyi | 4 +- sc2/client.py | 2 +- sc2/data.py | 1 - sc2/data.pyi | 633 +++++++++++++++++----------------- sc2/game_data.py | 19 +- sc2/py.typed | 1 + sc2/sc2process.py | 2 - sc2/score.py | 7 +- sc2/unit.py | 14 +- sc2/units.py | 2 +- test/autotest_bot.py | 2 +- 11 files changed, 343 insertions(+), 344 deletions(-) diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index aff16112..4648b530 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -103,7 +103,7 @@ class UnitTypeData(Message): unit_alias: int tech_requirement: int require_attached: bool - attributes: list[Attribute] + attributes: list[int] movement_speed: float armor: float weapons: list[Weapon] @@ -127,7 +127,7 @@ class UnitTypeData(Message): unit_alias: int = ..., tech_requirement: int = ..., require_attached: bool = ..., - attributes: list[Attribute] = ..., + attributes: list[int] = ..., movement_speed: float = ..., armor: float = ..., weapons: list[Weapon] = ..., diff --git a/sc2/client.py b/sc2/client.py index 2ad47ca2..dd5fedad 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -157,7 +157,7 @@ async def step(self, step_size: int = None): return await self._execute(step=sc_pb.RequestStep(count=step_size)) async def get_game_data(self) -> GameData: - result = await self._execute( + result: sc_pb.ResponseData = await self._execute( data=sc_pb.RequestData(ability_id=True, unit_type_id=True, upgrade_id=True, buff_id=True, effect_id=True) ) return GameData(result.data) diff --git a/sc2/data.py b/sc2/data.py index ce271edc..d376138b 100644 --- a/sc2/data.py +++ b/sc2/data.py @@ -8,7 +8,6 @@ import enum -# pyre-ignore[21] from s2clientprotocol import common_pb2 as common_pb from s2clientprotocol import data_pb2 as data_pb from s2clientprotocol import error_pb2 as error_pb diff --git a/sc2/data.pyi b/sc2/data.pyi index b625d769..57ad0c1a 100644 --- a/sc2/data.pyi +++ b/sc2/data.pyi @@ -10,92 +10,92 @@ and mypy to understand the structure and members of these enums. from __future__ import annotations -from enum import Enum +from enum import Enum, IntEnum from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId # Enums created from sc2api_pb2 class CreateGameError(Enum): - MissingMap: int - InvalidMapPath: int - InvalidMapData: int - InvalidMapName: int - InvalidMapHandle: int - MissingPlayerSetup: int - InvalidPlayerSetup: int - MultiplayerUnsupported: int + MissingMap = 1 + InvalidMapPath = 2 + InvalidMapData = 3 + InvalidMapName = 4 + InvalidMapHandle = 5 + MissingPlayerSetup = 6 + InvalidPlayerSetup = 7 + MultiplayerUnsupported = 8 -class PlayerType(Enum): - Participant: int - Computer: int - Observer: int +class PlayerType(IntEnum): + Participant = 1 + Computer = 2 + Observer = 3 class Difficulty(Enum): - VeryEasy: int - Easy: int - Medium: int - MediumHard: int - Hard: int - Harder: int - VeryHard: int - CheatVision: int - CheatMoney: int - CheatInsane: int + VeryEasy = 1 + Easy = 2 + Medium = 3 + MediumHard = 4 + Hard = 5 + Harder = 6 + VeryHard = 7 + CheatVision = 8 + CheatMoney = 9 + CheatInsane = 10 class AIBuild(Enum): - RandomBuild: int - Rush: int - Timing: int - Power: int - Macro: int - Air: int + RandomBuild = 1 + Rush = 2 + Timing = 3 + Power = 4 + Macro = 5 + Air = 6 class Status(Enum): - launched: int - init_game: int - in_game: int - in_replay: int - ended: int - quit: int - unknown: int + launched = 1 + init_game = 2 + in_game = 3 + in_replay = 4 + ended = 5 + quit = 6 + unknown = 7 class Result(Enum): - Victory: int - Defeat: int - Tie: int - Undecided: int + Victory = 1 + Defeat = 2 + Tie = 3 + Undecided = 4 class Alert(Enum): - AlertError: int - AddOnComplete: int - BuildingComplete: int - BuildingUnderAttack: int - LarvaHatched: int - MergeComplete: int - MineralsExhausted: int - MorphComplete: int - MothershipComplete: int - MULEExpired: int - NuclearLaunchDetected: int - NukeComplete: int - NydusWormDetected: int - ResearchComplete: int - TrainError: int - TrainUnitComplete: int - TrainWorkerComplete: int - TransformationComplete: int - UnitUnderAttack: int - UpgradeComplete: int - VespeneExhausted: int - WarpInComplete: int + AlertError = 1 + AddOnComplete = 2 + BuildingComplete = 3 + BuildingUnderAttack = 4 + LarvaHatched = 5 + MergeComplete = 6 + MineralsExhausted = 7 + MorphComplete = 8 + MothershipComplete = 9 + MULEExpired = 10 + NuclearLaunchDetected = 11 + NukeComplete = 12 + NydusWormDetected = 13 + ResearchComplete = 14 + TrainError = 15 + TrainUnitComplete = 16 + TrainWorkerComplete = 17 + TransformationComplete = 18 + UnitUnderAttack = 19 + UpgradeComplete = 20 + VespeneExhausted = 21 + WarpInComplete = 22 class ChatChannel(Enum): - Broadcast: int - Team: int + Broadcast = 1 + Team = 2 # Enums created from common_pb2 -class Race(Enum): +class Race(IntEnum): """StarCraft II race enum. Members: @@ -106,60 +106,59 @@ class Race(Enum): Random: Random race selection """ - NoRace: int - Terran: int - Zerg: int - Protoss: int - Random: int + NoRace = 0 + Terran = 1 + Zerg = 2 + Protoss = 3 + Random = 4 # Enums created from raw_pb2 class DisplayType(Enum): - Visible: int - Snapshot: int - Hidden: int - Placeholder: int + Visible = 1 + Snapshot = 2 + Hidden = 3 + Placeholder = 4 class Alliance(Enum): - Self: int - Ally: int - Neutral: int - Enemy: int + Self = 1 + Ally = 2 + Neutral = 3 + Enemy = 4 class CloakState(Enum): - CloakedUnknown: int - Cloaked: int - CloakedDetected: int - NotCloaked: int - CloakedAllied: int + CloakedUnknown = 1 + Cloaked = 2 + CloakedDetected = 3 + NotCloaked = 4 + CloakedAllied = 5 # Enums created from data_pb2 class Attribute(Enum): - Light: int - Armored: int - Biological: int - Mechanical: int - Robotic: int - Psionic: int - Massive: int - Structure: int - Hover: int - Heroic: int - Summoned: int - Invalid: int + Light = 1 + Armored = 2 + Biological = 3 + Mechanical = 4 + Robotic = 5 + Psionic = 6 + Massive = 7 + Structure = 8 + Hover = 9 + Heroic = 10 + Summoned = 11 class TargetType(Enum): - Ground: int - Air: int - Any: int - Invalid: int + Ground = 1 + Air = 2 + Any = 3 + Invalid = 4 class Target(Enum): # Note: The protobuf enum member 'None' is a Python keyword, # so at runtime it may need special handling - Point: int - Unit: int - PointOrUnit: int - PointOrNone: int + Point = 1 + Unit = 2 + PointOrUnit = 3 + PointOrNone = 4 # Enums created from error_pb2 class ActionResult(Enum): @@ -169,220 +168,220 @@ class ActionResult(Enum): various action results and error conditions. """ - Success: int - NotSupported: int - Error: int - CantQueueThatOrder: int - Retry: int - Cooldown: int - QueueIsFull: int - RallyQueueIsFull: int - NotEnoughMinerals: int - NotEnoughVespene: int - NotEnoughTerrazine: int - NotEnoughCustom: int - NotEnoughFood: int - FoodUsageImpossible: int - NotEnoughLife: int - NotEnoughShields: int - NotEnoughEnergy: int - LifeSuppressed: int - ShieldsSuppressed: int - EnergySuppressed: int - NotEnoughCharges: int - CantAddMoreCharges: int - TooMuchMinerals: int - TooMuchVespene: int - TooMuchTerrazine: int - TooMuchCustom: int - TooMuchFood: int - TooMuchLife: int - TooMuchShields: int - TooMuchEnergy: int - MustTargetUnitWithLife: int - MustTargetUnitWithShields: int - MustTargetUnitWithEnergy: int - CantTrade: int - CantSpend: int - CantTargetThatUnit: int - CouldntAllocateUnit: int - UnitCantMove: int - TransportIsHoldingPosition: int - BuildTechRequirementsNotMet: int - CantFindPlacementLocation: int - CantBuildOnThat: int - CantBuildTooCloseToDropOff: int - CantBuildLocationInvalid: int - CantSeeBuildLocation: int - CantBuildTooCloseToCreepSource: int - CantBuildTooCloseToResources: int - CantBuildTooFarFromWater: int - CantBuildTooFarFromCreepSource: int - CantBuildTooFarFromBuildPowerSource: int - CantBuildOnDenseTerrain: int - CantTrainTooFarFromTrainPowerSource: int - CantLandLocationInvalid: int - CantSeeLandLocation: int - CantLandTooCloseToCreepSource: int - CantLandTooCloseToResources: int - CantLandTooFarFromWater: int - CantLandTooFarFromCreepSource: int - CantLandTooFarFromBuildPowerSource: int - CantLandTooFarFromTrainPowerSource: int - CantLandOnDenseTerrain: int - AddOnTooFarFromBuilding: int - MustBuildRefineryFirst: int - BuildingIsUnderConstruction: int - CantFindDropOff: int - CantLoadOtherPlayersUnits: int - NotEnoughRoomToLoadUnit: int - CantUnloadUnitsThere: int - CantWarpInUnitsThere: int - CantLoadImmobileUnits: int - CantRechargeImmobileUnits: int - CantRechargeUnderConstructionUnits: int - CantLoadThatUnit: int - NoCargoToUnload: int - LoadAllNoTargetsFound: int - NotWhileOccupied: int - CantAttackWithoutAmmo: int - CantHoldAnyMoreAmmo: int - TechRequirementsNotMet: int - MustLockdownUnitFirst: int - MustTargetUnit: int - MustTargetInventory: int - MustTargetVisibleUnit: int - MustTargetVisibleLocation: int - MustTargetWalkableLocation: int - MustTargetPawnableUnit: int - YouCantControlThatUnit: int - YouCantIssueCommandsToThatUnit: int - MustTargetResources: int - RequiresHealTarget: int - RequiresRepairTarget: int - NoItemsToDrop: int - CantHoldAnyMoreItems: int - CantHoldThat: int - TargetHasNoInventory: int - CantDropThisItem: int - CantMoveThisItem: int - CantPawnThisUnit: int - MustTargetCaster: int - CantTargetCaster: int - MustTargetOuter: int - CantTargetOuter: int - MustTargetYourOwnUnits: int - CantTargetYourOwnUnits: int - MustTargetFriendlyUnits: int - CantTargetFriendlyUnits: int - MustTargetNeutralUnits: int - CantTargetNeutralUnits: int - MustTargetEnemyUnits: int - CantTargetEnemyUnits: int - MustTargetAirUnits: int - CantTargetAirUnits: int - MustTargetGroundUnits: int - CantTargetGroundUnits: int - MustTargetStructures: int - CantTargetStructures: int - MustTargetLightUnits: int - CantTargetLightUnits: int - MustTargetArmoredUnits: int - CantTargetArmoredUnits: int - MustTargetBiologicalUnits: int - CantTargetBiologicalUnits: int - MustTargetHeroicUnits: int - CantTargetHeroicUnits: int - MustTargetRoboticUnits: int - CantTargetRoboticUnits: int - MustTargetMechanicalUnits: int - CantTargetMechanicalUnits: int - MustTargetPsionicUnits: int - CantTargetPsionicUnits: int - MustTargetMassiveUnits: int - CantTargetMassiveUnits: int - MustTargetMissile: int - CantTargetMissile: int - MustTargetWorkerUnits: int - CantTargetWorkerUnits: int - MustTargetEnergyCapableUnits: int - CantTargetEnergyCapableUnits: int - MustTargetShieldCapableUnits: int - CantTargetShieldCapableUnits: int - MustTargetFlyers: int - CantTargetFlyers: int - MustTargetBuriedUnits: int - CantTargetBuriedUnits: int - MustTargetCloakedUnits: int - CantTargetCloakedUnits: int - MustTargetUnitsInAStasisField: int - CantTargetUnitsInAStasisField: int - MustTargetUnderConstructionUnits: int - CantTargetUnderConstructionUnits: int - MustTargetDeadUnits: int - CantTargetDeadUnits: int - MustTargetRevivableUnits: int - CantTargetRevivableUnits: int - MustTargetHiddenUnits: int - CantTargetHiddenUnits: int - CantRechargeOtherPlayersUnits: int - MustTargetHallucinations: int - CantTargetHallucinations: int - MustTargetInvulnerableUnits: int - CantTargetInvulnerableUnits: int - MustTargetDetectedUnits: int - CantTargetDetectedUnits: int - CantTargetUnitWithEnergy: int - CantTargetUnitWithShields: int - MustTargetUncommandableUnits: int - CantTargetUncommandableUnits: int - MustTargetPreventDefeatUnits: int - CantTargetPreventDefeatUnits: int - MustTargetPreventRevealUnits: int - CantTargetPreventRevealUnits: int - MustTargetPassiveUnits: int - CantTargetPassiveUnits: int - MustTargetStunnedUnits: int - CantTargetStunnedUnits: int - MustTargetSummonedUnits: int - CantTargetSummonedUnits: int - MustTargetUser1: int - CantTargetUser1: int - MustTargetUnstoppableUnits: int - CantTargetUnstoppableUnits: int - MustTargetResistantUnits: int - CantTargetResistantUnits: int - MustTargetDazedUnits: int - CantTargetDazedUnits: int - CantLockdown: int - CantMindControl: int - MustTargetDestructibles: int - CantTargetDestructibles: int - MustTargetItems: int - CantTargetItems: int - NoCalldownAvailable: int - WaypointListFull: int - MustTargetRace: int - CantTargetRace: int - MustTargetSimilarUnits: int - CantTargetSimilarUnits: int - CantFindEnoughTargets: int - AlreadySpawningLarva: int - CantTargetExhaustedResources: int - CantUseMinimap: int - CantUseInfoPanel: int - OrderQueueIsFull: int - CantHarvestThatResource: int - HarvestersNotRequired: int - AlreadyTargeted: int - CantAttackWeaponsDisabled: int - CouldntReachTarget: int - TargetIsOutOfRange: int - TargetIsTooClose: int - TargetIsOutOfArc: int - CantFindTeleportLocation: int - InvalidItemClass: int - CantFindCancelOrder: int + Success = 1 + NotSupported = 2 + Error = 3 + CantQueueThatOrder = 4 + Retry = 5 + Cooldown = 6 + QueueIsFull = 7 + RallyQueueIsFull = 8 + NotEnoughMinerals = 9 + NotEnoughVespene = 10 + NotEnoughTerrazine = 11 + NotEnoughCustom = 12 + NotEnoughFood = 13 + FoodUsageImpossible = 14 + NotEnoughLife = 15 + NotEnoughShields = 16 + NotEnoughEnergy = 17 + LifeSuppressed = 18 + ShieldsSuppressed = 19 + EnergySuppressed = 20 + NotEnoughCharges = 21 + CantAddMoreCharges = 22 + TooMuchMinerals = 23 + TooMuchVespene = 24 + TooMuchTerrazine = 25 + TooMuchCustom = 26 + TooMuchFood = 27 + TooMuchLife = 28 + TooMuchShields = 29 + TooMuchEnergy = 30 + MustTargetUnitWithLife = 31 + MustTargetUnitWithShields = 32 + MustTargetUnitWithEnergy = 33 + CantTrade = 34 + CantSpend = 35 + CantTargetThatUnit = 36 + CouldntAllocateUnit = 37 + UnitCantMove = 38 + TransportIsHoldingPosition = 39 + BuildTechRequirementsNotMet = 40 + CantFindPlacementLocation = 41 + CantBuildOnThat = 42 + CantBuildTooCloseToDropOff = 43 + CantBuildLocationInvalid = 44 + CantSeeBuildLocation = 45 + CantBuildTooCloseToCreepSource = 46 + CantBuildTooCloseToResources = 47 + CantBuildTooFarFromWater = 48 + CantBuildTooFarFromCreepSource = 49 + CantBuildTooFarFromBuildPowerSource = 50 + CantBuildOnDenseTerrain = 51 + CantTrainTooFarFromTrainPowerSource = 52 + CantLandLocationInvalid = 53 + CantSeeLandLocation = 54 + CantLandTooCloseToCreepSource = 55 + CantLandTooCloseToResources = 56 + CantLandTooFarFromWater = 57 + CantLandTooFarFromCreepSource = 58 + CantLandTooFarFromBuildPowerSource = 59 + CantLandTooFarFromTrainPowerSource = 60 + CantLandOnDenseTerrain = 61 + AddOnTooFarFromBuilding = 62 + MustBuildRefineryFirst = 63 + BuildingIsUnderConstruction = 64 + CantFindDropOff = 65 + CantLoadOtherPlayersUnits = 66 + NotEnoughRoomToLoadUnit = 67 + CantUnloadUnitsThere = 68 + CantWarpInUnitsThere = 69 + CantLoadImmobileUnits = 70 + CantRechargeImmobileUnits = 71 + CantRechargeUnderConstructionUnits = 72 + CantLoadThatUnit = 73 + NoCargoToUnload = 74 + LoadAllNoTargetsFound = 75 + NotWhileOccupied = 76 + CantAttackWithoutAmmo = 77 + CantHoldAnyMoreAmmo = 78 + TechRequirementsNotMet = 79 + MustLockdownUnitFirst = 80 + MustTargetUnit = 81 + MustTargetInventory = 82 + MustTargetVisibleUnit = 83 + MustTargetVisibleLocation = 84 + MustTargetWalkableLocation = 85 + MustTargetPawnableUnit = 86 + YouCantControlThatUnit = 87 + YouCantIssueCommandsToThatUnit = 88 + MustTargetResources = 89 + RequiresHealTarget = 90 + RequiresRepairTarget = 91 + NoItemsToDrop = 92 + CantHoldAnyMoreItems = 93 + CantHoldThat = 94 + TargetHasNoInventory = 95 + CantDropThisItem = 96 + CantMoveThisItem = 97 + CantPawnThisUnit = 98 + MustTargetCaster = 99 + CantTargetCaster = 100 + MustTargetOuter = 101 + CantTargetOuter = 102 + MustTargetYourOwnUnits = 103 + CantTargetYourOwnUnits = 104 + MustTargetFriendlyUnits = 105 + CantTargetFriendlyUnits = 106 + MustTargetNeutralUnits = 107 + CantTargetNeutralUnits = 108 + MustTargetEnemyUnits = 109 + CantTargetEnemyUnits = 110 + MustTargetAirUnits = 111 + CantTargetAirUnits = 112 + MustTargetGroundUnits = 113 + CantTargetGroundUnits = 114 + MustTargetStructures = 115 + CantTargetStructures = 116 + MustTargetLightUnits = 117 + CantTargetLightUnits = 118 + MustTargetArmoredUnits = 119 + CantTargetArmoredUnits = 120 + MustTargetBiologicalUnits = 121 + CantTargetBiologicalUnits = 122 + MustTargetHeroicUnits = 123 + CantTargetHeroicUnits = 124 + MustTargetRoboticUnits = 125 + CantTargetRoboticUnits = 126 + MustTargetMechanicalUnits = 127 + CantTargetMechanicalUnits = 128 + MustTargetPsionicUnits = 129 + CantTargetPsionicUnits = 130 + MustTargetMassiveUnits = 131 + CantTargetMassiveUnits = 132 + MustTargetMissile = 133 + CantTargetMissile = 134 + MustTargetWorkerUnits = 135 + CantTargetWorkerUnits = 136 + MustTargetEnergyCapableUnits = 137 + CantTargetEnergyCapableUnits = 138 + MustTargetShieldCapableUnits = 139 + CantTargetShieldCapableUnits = 140 + MustTargetFlyers = 141 + CantTargetFlyers = 142 + MustTargetBuriedUnits = 143 + CantTargetBuriedUnits = 144 + MustTargetCloakedUnits = 145 + CantTargetCloakedUnits = 146 + MustTargetUnitsInAStasisField = 147 + CantTargetUnitsInAStasisField = 148 + MustTargetUnderConstructionUnits = 149 + CantTargetUnderConstructionUnits = 150 + MustTargetDeadUnits = 151 + CantTargetDeadUnits = 152 + MustTargetRevivableUnits = 153 + CantTargetRevivableUnits = 154 + MustTargetHiddenUnits = 155 + CantTargetHiddenUnits = 156 + CantRechargeOtherPlayersUnits = 157 + MustTargetHallucinations = 158 + CantTargetHallucinations = 159 + MustTargetInvulnerableUnits = 160 + CantTargetInvulnerableUnits = 161 + MustTargetDetectedUnits = 162 + CantTargetDetectedUnits = 163 + CantTargetUnitWithEnergy = 164 + CantTargetUnitWithShields = 165 + MustTargetUncommandableUnits = 166 + CantTargetUncommandableUnits = 167 + MustTargetPreventDefeatUnits = 168 + CantTargetPreventDefeatUnits = 169 + MustTargetPreventRevealUnits = 170 + CantTargetPreventRevealUnits = 171 + MustTargetPassiveUnits = 172 + CantTargetPassiveUnits = 173 + MustTargetStunnedUnits = 174 + CantTargetStunnedUnits = 175 + MustTargetSummonedUnits = 176 + CantTargetSummonedUnits = 177 + MustTargetUser1 = 178 + CantTargetUser1 = 179 + MustTargetUnstoppableUnits = 180 + CantTargetUnstoppableUnits = 181 + MustTargetResistantUnits = 182 + CantTargetResistantUnits = 183 + MustTargetDazedUnits = 184 + CantTargetDazedUnits = 185 + CantLockdown = 186 + CantMindControl = 187 + MustTargetDestructibles = 188 + CantTargetDestructibles = 189 + MustTargetItems = 190 + CantTargetItems = 191 + NoCalldownAvailable = 192 + WaypointListFull = 193 + MustTargetRace = 194 + CantTargetRace = 195 + MustTargetSimilarUnits = 196 + CantTargetSimilarUnits = 197 + CantFindEnoughTargets = 198 + AlreadySpawningLarva = 199 + CantTargetExhaustedResources = 200 + CantUseMinimap = 201 + CantUseInfoPanel = 202 + OrderQueueIsFull = 203 + CantHarvestThatResource = 204 + HarvestersNotRequired = 205 + AlreadyTargeted = 206 + CantAttackWeaponsDisabled = 207 + CouldntReachTarget = 208 + TargetIsOutOfRange = 209 + TargetIsTooClose = 210 + TargetIsOutOfArc = 211 + CantFindTeleportLocation = 212 + InvalidItemClass = 213 + CantFindCancelOrder = 214 # Module-level dictionaries race_worker: dict[Race, UnitTypeId] diff --git a/sc2/game_data.py b/sc2/game_data.py index 3bc4fc78..1a60963e 100644 --- a/sc2/game_data.py +++ b/sc2/game_data.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import lru_cache +from s2clientprotocol import data_pb2, sc2api_pb2 from sc2.data import Attribute, Race from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId @@ -20,7 +21,7 @@ class GameData: - def __init__(self, data) -> None: + def __init__(self, data: sc2api_pb2.ResponseData) -> None: """ :param data: """ @@ -77,14 +78,14 @@ class AbilityData: ability_ids: list[int] = [ability_id.value for ability_id in AbilityId][1:] # sorted list @classmethod - def id_exists(cls, ability_id): + def id_exists(cls, ability_id: data_pb2.AbilityData | int) -> bool: assert isinstance(ability_id, int), f"Wrong type: {ability_id} is not int" if ability_id == 0: return False i = bisect_left(cls.ability_ids, ability_id) # quick binary search return i != len(cls.ability_ids) and cls.ability_ids[i] == ability_id - def __init__(self, game_data, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.AbilityData) -> None: self._game_data = game_data self._proto = proto @@ -131,7 +132,7 @@ def cost(self) -> Cost: class UnitTypeData: - def __init__(self, game_data: GameData, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.UnitTypeData) -> None: """ :param game_data: :param proto: @@ -172,12 +173,10 @@ def footprint_radius(self) -> float | None: return self.creation_ability._proto.footprint_radius @property - # pyre-ignore[11] def attributes(self) -> list[Attribute]: - return self._proto.attributes + return [Attribute(i) for i in self._proto.attributes] - def has_attribute(self, attr) -> bool: - # pyre-ignore[6] + def has_attribute(self, attr: Attribute) -> bool: assert isinstance(attr, Attribute) return attr in self.attributes @@ -237,7 +236,7 @@ def cost(self) -> Cost: def cost_zerg_corrected(self) -> Cost: """This returns 25 for extractor and 200 for spawning pool instead of 75 and 250 respectively""" # pyre-ignore[16] - if self.race == Race.Zerg and Attribute.Structure.value in self.attributes: + if self.race.value == Race.Zerg and Attribute.Structure in self.attributes: return Cost(self._proto.mineral_cost - 50, self._proto.vespene_cost, self._proto.build_time) return self.cost @@ -280,7 +279,7 @@ def morph_cost(self) -> Cost | None: class UpgradeData: - def __init__(self, game_data: GameData, proto) -> None: + def __init__(self, game_data: GameData, proto: data_pb2.UpgradeData) -> None: """ :param game_data: :param proto: diff --git a/sc2/py.typed b/sc2/py.typed index e69de29b..d360fd84 100644 --- a/sc2/py.typed +++ b/sc2/py.typed @@ -0,0 +1 @@ +# Required by https://peps.python.org/pep-0561/#packaging-type-information diff --git a/sc2/sc2process.py b/sc2/sc2process.py index 846dc480..3fa1777a 100644 --- a/sc2/sc2process.py +++ b/sc2/sc2process.py @@ -2,7 +2,6 @@ import asyncio import os -import os.path import shutil import signal import subprocess @@ -143,7 +142,6 @@ def find_data_hash(self, target_sc2_version: str) -> str | None: def find_base_dir(self, target_sc2_version: str) -> str | None: """Returns the base directory from the matching version string.""" - version: dict for version in self.versions: if version["label"] == target_sc2_version: return "Base" + str(version["base-version"]) diff --git a/sc2/score.py b/sc2/score.py index 9b8f5f2c..18df8f38 100644 --- a/sc2/score.py +++ b/sc2/score.py @@ -1,14 +1,17 @@ +from s2clientprotocol import score_pb2 + + class ScoreDetails: """Accessable in self.state.score during step function For more information, see https://github.com/Blizzard/s2client-proto/blob/master/s2clientprotocol/score.proto """ - def __init__(self, proto) -> None: + def __init__(self, proto: score_pb2.Score) -> None: self._data = proto self._proto = proto.score_details @property - def summary(self): + def summary(self) -> list[list[int | float]]: """ TODO this is super ugly, how can we improve this summary? Print summary to file with: diff --git a/sc2/unit.py b/sc2/unit.py index 05949505..ab8dab86 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -163,37 +163,37 @@ def tag(self) -> int: @property def is_structure(self) -> bool: """Checks if the unit is a structure.""" - return IS_STRUCTURE in self._type_data.attributes + return IS_STRUCTURE in self._type_data._proto.attributes @property def is_light(self) -> bool: """Checks if the unit has the 'light' attribute.""" - return IS_LIGHT in self._type_data.attributes + return IS_LIGHT in self._type_data._proto.attributes @property def is_armored(self) -> bool: """Checks if the unit has the 'armored' attribute.""" - return IS_ARMORED in self._type_data.attributes + return IS_ARMORED in self._type_data._proto.attributes @property def is_biological(self) -> bool: """Checks if the unit has the 'biological' attribute.""" - return IS_BIOLOGICAL in self._type_data.attributes + return IS_BIOLOGICAL in self._type_data._proto.attributes @property def is_mechanical(self) -> bool: """Checks if the unit has the 'mechanical' attribute.""" - return IS_MECHANICAL in self._type_data.attributes + return IS_MECHANICAL in self._type_data._proto.attributes @property def is_massive(self) -> bool: """Checks if the unit has the 'massive' attribute.""" - return IS_MASSIVE in self._type_data.attributes + return IS_MASSIVE in self._type_data._proto.attributes @property def is_psionic(self) -> bool: """Checks if the unit has the 'psionic' attribute.""" - return IS_PSIONIC in self._type_data.attributes + return IS_PSIONIC in self._type_data._proto.attributes @cached_property def tech_alias(self) -> list[UnitTypeId] | None: diff --git a/sc2/units.py b/sc2/units.py index c41d8ed7..1871dfc6 100644 --- a/sc2/units.py +++ b/sc2/units.py @@ -15,7 +15,7 @@ from sc2.bot_ai import BotAI -class Units(list): +class Units(list[Unit]): """A collection of Unit objects. Makes it easy to select units by selectors.""" @classmethod diff --git a/test/autotest_bot.py b/test/autotest_bot.py index 10abb82e..a6a0dc23 100644 --- a/test/autotest_bot.py +++ b/test/autotest_bot.py @@ -458,7 +458,7 @@ async def test_botai_actions12(self): # Pick scv scv: Unit = self.workers.random # Pick location to build depot on - placement_position: Point2 = await self.find_placement( + placement_position: Point2 | None = await self.find_placement( UnitTypeId.SUPPLYDEPOT, near=self.townhalls.random.position ) if placement_position: From e4eaa30fba844375aec57ab4e2a34f8cc4032fa7 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Mon, 17 Nov 2025 17:33:55 +0100 Subject: [PATCH 17/34] Apply more type hints --- sc2/action.py | 5 +---- sc2/client.py | 5 +++-- sc2/expiring_dict.py | 26 +++++++++++++------------- sc2/player.py | 24 +++++++++++++----------- sc2/unit.py | 7 ++++--- 5 files changed, 34 insertions(+), 33 deletions(-) diff --git a/sc2/action.py b/sc2/action.py index 0500309e..d0e534ba 100644 --- a/sc2/action.py +++ b/sc2/action.py @@ -5,7 +5,6 @@ # pyre-ignore[21] from s2clientprotocol import raw_pb2 as raw_pb - from sc2.position import Point2 from sc2.unit import Unit @@ -14,7 +13,7 @@ from sc2.unit_command import UnitCommand -def combine_actions(action_iter): +def combine_actions(action_iter: list[UnitCommand]): """ Example input: [ @@ -57,7 +56,6 @@ def combine_actions(action_iter): I imagine the same thing would happen to certain other abilities: Battlecruiser yamato on same target, queen transfuse on same target, ghost snipe on same target, all build commands with the same unit type and also all morphs (zergling to banelings) However, other abilities can and should be grouped, see constants.py 'COMBINEABLE_ABILITIES' """ - u: UnitCommand if target is None: for u in items: cmd = raw_pb.ActionRawUnitCommand( @@ -73,7 +71,6 @@ def combine_actions(action_iter): target_world_space_pos=target.as_Point2D, ) yield raw_pb.ActionRaw(unit_command=cmd) - elif isinstance(target, Unit): for u in items: cmd = raw_pb.ActionRawUnitCommand( diff --git a/sc2/client.py b/sc2/client.py index dd5fedad..cde42374 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -22,6 +22,7 @@ from sc2.protocol import ConnectionAlreadyClosedError, Protocol, ProtocolError from sc2.renderer import Renderer from sc2.unit import Unit +from sc2.unit_command import UnitCommand from sc2.units import Units @@ -193,9 +194,9 @@ async def get_game_info(self) -> GameInfo: result = await self._execute(game_info=sc_pb.RequestGameInfo()) return GameInfo(result.game_info) - async def actions(self, actions, return_successes: bool = False): + async def actions(self, actions: list[UnitCommand], return_successes: bool = False) -> list[ActionResult]: if not actions: - return None + return [] if not isinstance(actions, list): actions = [actions] diff --git a/sc2/expiring_dict.py b/sc2/expiring_dict.py index ebbcb23c..7c6cc2c0 100644 --- a/sc2/expiring_dict.py +++ b/sc2/expiring_dict.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Hashable, Iterable from threading import RLock from typing import TYPE_CHECKING, Any @@ -10,7 +10,7 @@ from sc2.bot_ai import BotAI -class ExpiringDict(OrderedDict): +class ExpiringDict(OrderedDict[Hashable, Any]): """ An expiring dict that uses the bot.state.game_loop to only return items that are valid for a specific amount of time. @@ -45,7 +45,7 @@ def frame(self) -> int: # pyre-ignore[16] return self.bot.state.game_loop - def __contains__(self, key) -> bool: + def __contains__(self, key: Hashable) -> bool: """Return True if dict has key, else False, e.g. 'key in dict'""" with self.lock: if OrderedDict.__contains__(self, key): @@ -56,7 +56,7 @@ def __contains__(self, key) -> bool: del self[key] return False - def __getitem__(self, key, with_age: bool = False) -> Any: + def __getitem__(self, key: Hashable, with_age: bool = False) -> Any: """Return the item of the dict using d[key]""" with self.lock: # Each item is a list of [value, frame time] @@ -68,7 +68,7 @@ def __getitem__(self, key, with_age: bool = False) -> Any: OrderedDict.__delitem__(self, key) raise KeyError(key) - def __setitem__(self, key, value) -> None: + def __setitem__(self, key: Hashable, value: Any) -> None: """Set d[key] = value""" with self.lock: OrderedDict.__setitem__(self, key, (value, self.frame)) @@ -83,10 +83,10 @@ def __repr__(self) -> str: print_str = ", ".join(print_list) return f"ExpiringDict({print_str})" - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def __iter__(self): + def __iter__(self) -> Iterable[Hashable]: """Override 'for key in dict:'""" with self.lock: return self.keys() @@ -101,7 +101,7 @@ def __len__(self) -> int: count += 1 return count - def pop(self, key, default=None, with_age: bool = False): + def pop(self, key: Hashable, default: Any = None, with_age: bool = False): """Return the item and remove it""" with self.lock: if OrderedDict.__contains__(self, key): @@ -118,7 +118,7 @@ def pop(self, key, default=None, with_age: bool = False): return default, self.frame return default - def get(self, key, default=None, with_age: bool = False): + def get(self, key: Hashable, default: Any = None, with_age: bool = False): """Return the value for key if key is in dict, else default""" with self.lock: if OrderedDict.__contains__(self, key): @@ -134,26 +134,26 @@ def get(self, key, default=None, with_age: bool = False): return None return None - def update(self, other_dict: dict) -> None: + def update(self, other_dict: dict[Hashable, Any]) -> None: with self.lock: for key, value in other_dict.items(): self[key] = value - def items(self) -> Iterable: + def items(self) -> Iterable[tuple[Hashable, Any]]: """Return iterator of zipped list [keys, values]""" with self.lock: for key, value in OrderedDict.items(self): if self.frame - value[1] < self.max_age: yield key, value[0] - def keys(self) -> Iterable: + def keys(self) -> Iterable[Hashable]: """Return iterator of keys""" with self.lock: for key, value in OrderedDict.items(self): if self.frame - value[1] < self.max_age: yield key - def values(self) -> Iterable: + def values(self) -> Iterable[Any]: """Return iterator of values""" with self.lock: for value in OrderedDict.values(self): diff --git a/sc2/player.py b/sc2/player.py index 935535c0..7e7d5255 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -13,10 +13,10 @@ class AbstractPlayer(ABC): def __init__( self, p_type: PlayerType, - race: Race = None, + race: Race | None = None, name: str | None = None, - difficulty=None, - ai_build=None, + difficulty: Difficulty | None = None, + ai_build: AIBuild | None = None, fullscreen: bool = False, ) -> None: assert isinstance(p_type, PlayerType), f"p_type is of type {type(p_type)}" @@ -51,7 +51,7 @@ def needs_sc2(self) -> bool: class Human(AbstractPlayer): - def __init__(self, race, name: str | None = None, fullscreen: bool = False) -> None: + def __init__(self, race: Race, name: str | None = None, fullscreen: bool = False) -> None: super().__init__(PlayerType.Participant, race, name=name, fullscreen=fullscreen) def __str__(self) -> str: @@ -61,7 +61,7 @@ def __str__(self) -> str: class Bot(AbstractPlayer): - def __init__(self, race, ai, name: str | None = None, fullscreen: bool = False) -> None: + def __init__(self, race: Race, ai: BotAI, name: str | None = None, fullscreen: bool = False) -> None: """ AI can be None if this player object is just used to inform the server about player types. @@ -77,7 +77,9 @@ def __str__(self) -> str: class Computer(AbstractPlayer): - def __init__(self, race, difficulty=Difficulty.Easy, ai_build=AIBuild.RandomBuild) -> None: + def __init__( + self, race: Race, difficulty: Difficulty = Difficulty.Easy, ai_build: AIBuild = AIBuild.RandomBuild + ) -> None: super().__init__(PlayerType.Computer, race, difficulty=difficulty, ai_build=ai_build) def __str__(self) -> str: @@ -96,12 +98,12 @@ class Player(AbstractPlayer): def __init__( self, player_id: int, - p_type, - requested_race, - difficulty=None, - actual_race=None, + p_type: PlayerType, + requested_race: Race, + difficulty: Difficulty | None = None, + actual_race: Race = None, name: str | None = None, - ai_build=None, + ai_build: AIBuild | None = None, ) -> None: super().__init__(p_type, requested_race, difficulty=difficulty, name=name, ai_build=ai_build) self.id: int = player_id diff --git a/sc2/unit.py b/sc2/unit.py index ab8dab86..9da68762 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -63,6 +63,7 @@ if TYPE_CHECKING: from sc2.bot_ai import BotAI + from sc2.bot_ai_internal import BotAIInternal from sc2.game_data import AbilityData, UnitTypeData @@ -108,7 +109,7 @@ class Unit: def __init__( self, proto_data: raw_pb2.Unit, - bot_object: BotAI, + bot_object: BotAI | BotAIInternal, distance_calculation_index: int = -1, base_build: int = -1, ) -> None: @@ -119,7 +120,7 @@ def __init__( :param base_build: """ self._proto = proto_data - self._bot_object: BotAI = bot_object + self._bot_object = bot_object self.game_loop: int = bot_object.state.game_loop self.base_build = base_build # Index used in the 2D numpy array to access the 2D distance between two units @@ -706,7 +707,7 @@ def calculate_damage_vs_target( # TODO: hardcode hellbats when they have blueflame or attack upgrades for bonus in weapon.damage_bonus: # More about damage bonus https://github.com/Blizzard/s2client-proto/blob/b73eb59ac7f2c52b2ca585db4399f2d3202e102a/s2clientprotocol/data.proto#L55 - if bonus.attribute in target._type_data.attributes: + if bonus.attribute in target._type_data._proto.attributes: bonus_damage_per_upgrade = ( 0 if not self.attack_upgrade_level From 6fdaf69dff7db7a83e4c239381a11953dd26a4cf Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Mon, 17 Nov 2025 17:34:56 +0100 Subject: [PATCH 18/34] Replace list with iterable --- s2clientprotocol/data_pb2.pyi | 17 +++--- s2clientprotocol/debug_pb2.pyi | 21 +++---- s2clientprotocol/query_pb2.pyi | 30 +++++----- s2clientprotocol/raw_pb2.pyi | 57 +++++++++---------- s2clientprotocol/sc2api_pb2.pyi | 95 ++++++++++++++++---------------- s2clientprotocol/spatial_pb2.pyi | 5 +- s2clientprotocol/ui_pb2.pyi | 25 +++++---- 7 files changed, 129 insertions(+), 121 deletions(-) diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index 4648b530..00e286d8 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -1,4 +1,5 @@ from enum import Enum +from typing import Iterable from google.protobuf.message import Message @@ -70,7 +71,7 @@ class TargetType(Enum): class Weapon(Message): type: int damage: float - damage_bonus: list[DamageBonus] + damage_bonus: Iterable[DamageBonus] attacks: int range: float speed: float @@ -78,7 +79,7 @@ class Weapon(Message): self, type: int = ..., damage: float = ..., - damage_bonus: list[DamageBonus] = ..., + damage_bonus: Iterable[DamageBonus] = ..., attacks: int = ..., range: float = ..., speed: float = ..., @@ -99,14 +100,14 @@ class UnitTypeData(Message): has_vespene: bool has_minerals: bool sight_range: float - tech_alias: list[int] + tech_alias: Iterable[int] unit_alias: int tech_requirement: int require_attached: bool - attributes: list[int] + attributes: Iterable[int] movement_speed: float armor: float - weapons: list[Weapon] + weapons: Iterable[Weapon] def __init__( self, unit_id: int = ..., @@ -123,14 +124,14 @@ class UnitTypeData(Message): has_vespene: bool = ..., has_minerals: bool = ..., sight_range: float = ..., - tech_alias: list[int] = ..., + tech_alias: Iterable[int] = ..., unit_alias: int = ..., tech_requirement: int = ..., require_attached: bool = ..., - attributes: list[int] = ..., + attributes: Iterable[int] = ..., movement_speed: float = ..., armor: float = ..., - weapons: list[Weapon] = ..., + weapons: Iterable[Weapon] = ..., ) -> None: ... class UpgradeData(Message): diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi index fe710158..40f88db4 100644 --- a/s2clientprotocol/debug_pb2.pyi +++ b/s2clientprotocol/debug_pb2.pyi @@ -1,4 +1,5 @@ from enum import Enum +from typing import Iterable from google.protobuf.message import Message @@ -26,16 +27,16 @@ class DebugCommand(Message): ) -> None: ... class DebugDraw(Message): - text: list[DebugText] - lines: list[DebugLine] - boxes: list[DebugBox] - spheres: list[DebugSphere] + text: Iterable[DebugText] + lines: Iterable[DebugLine] + boxes: Iterable[DebugBox] + spheres: Iterable[DebugSphere] def __init__( self, - text: list[DebugText] = ..., - lines: list[DebugLine] = ..., - boxes: list[DebugBox] = ..., - spheres: list[DebugSphere] = ..., + text: Iterable[DebugText] = ..., + lines: Iterable[DebugLine] = ..., + boxes: Iterable[DebugBox] = ..., + spheres: Iterable[DebugSphere] = ..., ) -> None: ... class Line(Message): @@ -109,8 +110,8 @@ class DebugCreateUnit(Message): ) -> None: ... class DebugKillUnit(Message): - tag: list[int] - def __init__(self, tag: list[int] = ...) -> None: ... + tag: Iterable[int] + def __init__(self, tag: Iterable[int] = ...) -> None: ... class Test(Enum): hang: int diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index 2ba65f0a..19761aee 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,29 +1,31 @@ +from typing import Iterable + from google.protobuf.message import Message from .common_pb2 import AvailableAbility, Point2D class RequestQuery(Message): - pathing: list[RequestQueryPathing] - abilities: list[RequestQueryAvailableAbilities] - placements: list[RequestQueryBuildingPlacement] + pathing: Iterable[RequestQueryPathing] + abilities: Iterable[RequestQueryAvailableAbilities] + placements: Iterable[RequestQueryBuildingPlacement] ignore_resource_requirements: bool def __init__( self, - pathing: list[RequestQueryPathing] = ..., - abilities: list[RequestQueryAvailableAbilities] = ..., - placements: list[RequestQueryBuildingPlacement] = ..., + pathing: Iterable[RequestQueryPathing] = ..., + abilities: Iterable[RequestQueryAvailableAbilities] = ..., + placements: Iterable[RequestQueryBuildingPlacement] = ..., ignore_resource_requirements: bool = ..., ) -> None: ... class ResponseQuery(Message): - pathing: list[ResponseQueryPathing] - abilities: list[ResponseQueryAvailableAbilities] - placements: list[ResponseQueryBuildingPlacement] + pathing: Iterable[ResponseQueryPathing] + abilities: Iterable[ResponseQueryAvailableAbilities] + placements: Iterable[ResponseQueryBuildingPlacement] def __init__( self, - pathing: list[ResponseQueryPathing] = ..., - abilities: list[ResponseQueryAvailableAbilities] = ..., - placements: list[ResponseQueryBuildingPlacement] = ..., + pathing: Iterable[ResponseQueryPathing] = ..., + abilities: Iterable[ResponseQueryAvailableAbilities] = ..., + placements: Iterable[ResponseQueryBuildingPlacement] = ..., ) -> None: ... class RequestQueryPathing(Message): @@ -46,12 +48,12 @@ class RequestQueryAvailableAbilities(Message): def __init__(self, unit_tag: int = ...) -> None: ... class ResponseQueryAvailableAbilities(Message): - abilities: list[AvailableAbility] + abilities: Iterable[AvailableAbility] unit_tag: int unit_type_id: int def __init__( self, - abilities: list[AvailableAbility] = ..., + abilities: Iterable[AvailableAbility] = ..., unit_tag: int = ..., unit_type_id: int = ..., ) -> None: ... diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi index c41ab731..840b1efd 100644 --- a/s2clientprotocol/raw_pb2.pyi +++ b/s2clientprotocol/raw_pb2.pyi @@ -1,4 +1,5 @@ from enum import Enum +from typing import Iterable from google.protobuf.message import Message @@ -10,7 +11,7 @@ class StartRaw(Message): terrain_height: ImageData placement_grid: ImageData playable_area: RectangleI - start_locations: list[Point2D] + start_locations: Iterable[Point2D] def __init__( self, map_size: Size2DI = ..., @@ -18,24 +19,24 @@ class StartRaw(Message): terrain_height: ImageData = ..., placement_grid: ImageData = ..., playable_area: RectangleI = ..., - start_locations: list[Point2D] = ..., + start_locations: Iterable[Point2D] = ..., ) -> None: ... class ObservationRaw(Message): player: PlayerRaw - units: list[Unit] + units: Iterable[Unit] map_state: MapState event: Event - effects: list[Effect] - radar: list[RadarRing] + effects: Iterable[Effect] + radar: Iterable[RadarRing] def __init__( self, player: PlayerRaw = ..., - units: list[Unit] = ..., + units: Iterable[Unit] = ..., map_state: MapState = ..., event: Event = ..., - effects: list[Effect] = ..., - radar: list[RadarRing] = ..., + effects: Iterable[Effect] = ..., + radar: Iterable[RadarRing] = ..., ) -> None: ... class RadarRing(Message): @@ -50,14 +51,14 @@ class PowerSource(Message): def __init__(self, pos: Point = ..., radius: float = ..., tag: int = ...) -> None: ... class PlayerRaw(Message): - power_sources: list[PowerSource] + power_sources: Iterable[PowerSource] camera: Point - upgrade_ids: list[int] + upgrade_ids: Iterable[int] def __init__( self, - power_sources: list[PowerSource] = ..., + power_sources: Iterable[PowerSource] = ..., camera: Point = ..., - upgrade_ids: list[int] = ..., + upgrade_ids: Iterable[int] = ..., ) -> None: ... class UnitOrder(Message): @@ -129,7 +130,7 @@ class Unit(Message): radius: float build_progress: float cloak: int - buff_ids: list[int] + buff_ids: Iterable[int] detect_range: float radar_range: float is_selected: bool @@ -151,9 +152,9 @@ class Unit(Message): is_flying: bool is_burrowed: bool is_hallucination: bool - orders: list[UnitOrder] + orders: Iterable[UnitOrder] add_on_tag: int - passengers: list[PassengerUnit] + passengers: Iterable[PassengerUnit] cargo_space_taken: int cargo_space_max: int assigned_harvesters: int @@ -162,7 +163,7 @@ class Unit(Message): engaged_target_tag: int buff_duration_remain: int buff_duration_max: int - rally_targets: list[RallyTarget] + rally_targets: Iterable[RallyTarget] def __init__( self, display_type: int = ..., @@ -175,7 +176,7 @@ class Unit(Message): radius: float = ..., build_progress: float = ..., cloak: int = ..., - buff_ids: list[int] = ..., + buff_ids: Iterable[int] = ..., detect_range: float = ..., radar_range: float = ..., is_selected: bool = ..., @@ -197,9 +198,9 @@ class Unit(Message): is_flying: bool = ..., is_burrowed: bool = ..., is_hallucination: bool = ..., - orders: list[UnitOrder] = ..., + orders: Iterable[UnitOrder] = ..., add_on_tag: int = ..., - passengers: list[PassengerUnit] = ..., + passengers: Iterable[PassengerUnit] = ..., cargo_space_taken: int = ..., cargo_space_max: int = ..., assigned_harvesters: int = ..., @@ -208,7 +209,7 @@ class Unit(Message): engaged_target_tag: int = ..., buff_duration_remain: int = ..., buff_duration_max: int = ..., - rally_targets: list[RallyTarget] = ..., + rally_targets: Iterable[RallyTarget] = ..., ) -> None: ... class MapState(Message): @@ -217,19 +218,19 @@ class MapState(Message): def __init__(self, visibility: ImageData = ..., creep: ImageData = ...) -> None: ... class Event(Message): - dead_units: list[int] - def __init__(self, dead_units: list[int] = ...) -> None: ... + dead_units: Iterable[int] + def __init__(self, dead_units: Iterable[int] = ...) -> None: ... class Effect(Message): effect_id: int - pos: list[Point2D] + pos: Iterable[Point2D] alliance: int owner: int radius: float def __init__( self, effect_id: int = ..., - pos: list[Point2D] = ..., + pos: Iterable[Point2D] = ..., alliance: int = ..., owner: int = ..., radius: float = ..., @@ -250,14 +251,14 @@ class ActionRawUnitCommand(Message): ability_id: int target_world_space_pos: Point2D target_unit_tag: int - unit_tags: list[int] + unit_tags: Iterable[int] queue_command: bool def __init__( self, ability_id: int = ..., target_world_space_pos: Point2D = ..., target_unit_tag: int = ..., - unit_tags: list[int] = ..., + unit_tags: Iterable[int] = ..., queue_command: bool = ..., ) -> None: ... @@ -267,5 +268,5 @@ class ActionRawCameraMove(Message): class ActionRawToggleAutocast(Message): ability_id: int - unit_tags: list[int] - def __init__(self, ability_id: int = ..., unit_tags: list[int] = ...) -> None: ... + unit_tags: Iterable[int] + def __init__(self, ability_id: int = ..., unit_tags: Iterable[int] = ...) -> None: ... diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi index 23ad2d15..76def499 100644 --- a/s2clientprotocol/sc2api_pb2.pyi +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +from typing import Iterable from google.protobuf.message import Message @@ -89,7 +90,7 @@ class Response(Message): ping: ResponsePing debug: ResponseDebug id: int - error: list[str] + error: Iterable[str] status: int def __init__( self, @@ -116,7 +117,7 @@ class Response(Message): ping: ResponsePing = ..., debug: ResponseDebug = ..., id: int = ..., - error: list[str] = ..., + error: Iterable[str] = ..., status: int = ..., ) -> None: ... @@ -132,7 +133,7 @@ class Status(Enum): class RequestCreateGame(Message): local_map: LocalMap battlenet_map_name: str - player_setup: list[PlayerSetup] + player_setup: Iterable[PlayerSetup] disable_fog: bool random_seed: int realtime: bool @@ -140,7 +141,7 @@ class RequestCreateGame(Message): self, local_map: LocalMap = ..., battlenet_map_name: str = ..., - player_setup: list[PlayerSetup] = ..., + player_setup: Iterable[PlayerSetup] = ..., disable_fog: bool = ..., random_seed: int = ..., realtime: bool = ..., @@ -171,7 +172,7 @@ class RequestJoinGame(Message): observed_player_id: int options: InterfaceOptions server_ports: PortSet - client_ports: list[PortSet] + client_ports: Iterable[PortSet] shared_port: int player_name: str host_ip: str @@ -181,7 +182,7 @@ class RequestJoinGame(Message): observed_player_id: int = ..., options: InterfaceOptions = ..., server_ports: PortSet = ..., - client_ports: list[PortSet] = ..., + client_ports: Iterable[PortSet] = ..., shared_port: int = ..., player_name: str = ..., host_ip: str = ..., @@ -301,17 +302,17 @@ class RequestGameInfo(Message): class ResponseGameInfo(Message): map_name: str - mod_names: list[str] + mod_names: Iterable[str] local_map_path: str - player_info: list[PlayerInfo] + player_info: Iterable[PlayerInfo] start_raw: StartRaw options: InterfaceOptions def __init__( self, map_name: str = ..., - mod_names: list[str] = ..., + mod_names: Iterable[str] = ..., local_map_path: str = ..., - player_info: list[PlayerInfo] = ..., + player_info: Iterable[PlayerInfo] = ..., start_raw: StartRaw = ..., options: InterfaceOptions = ..., ) -> None: ... @@ -322,18 +323,18 @@ class RequestObservation(Message): def __init__(self, disable_fog: bool = ..., game_loop: int = ...) -> None: ... class ResponseObservation(Message): - actions: list[Action] - action_errors: list[ActionError] + actions: Iterable[Action] + action_errors: Iterable[ActionError] observation: Observation - player_result: list[PlayerResult] - chat: list[ChatReceived] + player_result: Iterable[PlayerResult] + chat: Iterable[ChatReceived] def __init__( self, - actions: list[Action] = ..., - action_errors: list[ActionError] = ..., + actions: Iterable[Action] = ..., + action_errors: Iterable[ActionError] = ..., observation: Observation = ..., - player_result: list[PlayerResult] = ..., - chat: list[ChatReceived] = ..., + player_result: Iterable[PlayerResult] = ..., + chat: Iterable[ChatReceived] = ..., ) -> None: ... class ChatReceived(Message): @@ -342,16 +343,16 @@ class ChatReceived(Message): def __init__(self, player_id: int = ..., message: str = ...) -> None: ... class RequestAction(Message): - actions: list[Action] - def __init__(self, actions: list[Action] = ...) -> None: ... + actions: Iterable[Action] + def __init__(self, actions: Iterable[Action] = ...) -> None: ... class ResponseAction(Message): - result: list[int] - def __init__(self, result: list[int] = ...) -> None: ... + result: Iterable[int] + def __init__(self, result: Iterable[int] = ...) -> None: ... class RequestObserverAction(Message): - actions: list[ObserverAction] - def __init__(self, actions: list[ObserverAction] = ...) -> None: ... + actions: Iterable[ObserverAction] + def __init__(self, actions: Iterable[ObserverAction] = ...) -> None: ... class ResponseObserverAction(Message): def __init__(self) -> None: ... @@ -380,18 +381,18 @@ class RequestData(Message): ) -> None: ... class ResponseData(Message): - abilities: list[AbilityData] - units: list[UnitTypeData] - upgrades: list[UpgradeData] - buffs: list[BuffData] - effects: list[EffectData] + abilities: Iterable[AbilityData] + units: Iterable[UnitTypeData] + upgrades: Iterable[UpgradeData] + buffs: Iterable[BuffData] + effects: Iterable[EffectData] def __init__( self, - abilities: list[AbilityData] = ..., - units: list[UnitTypeData] = ..., - upgrades: list[UpgradeData] = ..., - buffs: list[BuffData] = ..., - effects: list[EffectData] = ..., + abilities: Iterable[AbilityData] = ..., + units: Iterable[UnitTypeData] = ..., + upgrades: Iterable[UpgradeData] = ..., + buffs: Iterable[BuffData] = ..., + effects: Iterable[EffectData] = ..., ) -> None: ... class RequestSaveReplay(Message): @@ -435,7 +436,7 @@ class ResponseReplayInfo(Message): map_name: str local_map_path: str - player_info: list[PlayerInfoExtra] + player_info: Iterable[PlayerInfoExtra] game_duration_loops: int game_duration_seconds: float game_version: str @@ -448,7 +449,7 @@ class ResponseReplayInfo(Message): self, map_name: str = ..., local_map_path: str = ..., - player_info: list[PlayerInfoExtra] = ..., + player_info: Iterable[PlayerInfoExtra] = ..., game_duration_loops: int = ..., game_duration_seconds: float = ..., game_version: str = ..., @@ -463,9 +464,9 @@ class RequestAvailableMaps(Message): def __init__(self) -> None: ... class ResponseAvailableMaps(Message): - local_map_paths: list[str] - battlenet_map_names: list[str] - def __init__(self, local_map_paths: list[str] = ..., battlenet_map_names: list[str] = ...) -> None: ... + local_map_paths: Iterable[str] + battlenet_map_names: Iterable[str] + def __init__(self, local_map_paths: Iterable[str] = ..., battlenet_map_names: Iterable[str] = ...) -> None: ... class RequestSaveMap(Message): map_path: str @@ -496,8 +497,8 @@ class ResponsePing(Message): ) -> None: ... class RequestDebug(Message): - debug: list[DebugCommand] - def __init__(self, debug: list[DebugCommand] = ...) -> None: ... + debug: Iterable[DebugCommand] + def __init__(self, debug: Iterable[DebugCommand] = ...) -> None: ... class ResponseDebug(Message): def __init__(self) -> None: ... @@ -629,8 +630,8 @@ class PlayerCommon(Message): class Observation(Message): game_loop: int player_common: PlayerCommon - alerts: list[int] - abilities: list[AvailableAbility] + alerts: Iterable[int] + abilities: Iterable[AvailableAbility] score: Score raw_data: ObservationRaw feature_layer_data: ObservationFeatureLayer @@ -640,8 +641,8 @@ class Observation(Message): self, game_loop: int = ..., player_common: PlayerCommon = ..., - alerts: list[int] = ..., - abilities: list[AvailableAbility] = ..., + alerts: Iterable[int] = ..., + abilities: Iterable[AvailableAbility] = ..., score: Score = ..., raw_data: ObservationRaw = ..., feature_layer_data: ObservationFeatureLayer = ..., @@ -708,8 +709,8 @@ class ActionObserverCameraFollowPlayer(Message): def __init__(self, player_id: int = ...) -> None: ... class ActionObserverCameraFollowUnits(Message): - unit_tags: list[int] - def __init__(self, unit_tags: list[int] = ...) -> None: ... + unit_tags: Iterable[int] + def __init__(self, unit_tags: Iterable[int] = ...) -> None: ... class Alert(Enum): AlertError: int diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi index 0fad99e6..3e5925e8 100644 --- a/s2clientprotocol/spatial_pb2.pyi +++ b/s2clientprotocol/spatial_pb2.pyi @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +from typing import Iterable from google.protobuf.message import Message @@ -148,6 +149,6 @@ class ActionSpatialUnitSelectionPoint(Message): def __init__(self, selection_screen_coord: PointI = ..., type: int = ...) -> None: ... class ActionSpatialUnitSelectionRect(Message): - selection_screen_coord: list[RectangleI] + selection_screen_coord: Iterable[RectangleI] selection_add: bool - def __init__(self, selection_screen_coord: list[RectangleI] = ..., selection_add: bool = ...) -> None: ... + def __init__(self, selection_screen_coord: Iterable[RectangleI] = ..., selection_add: bool = ...) -> None: ... diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi index 236bb70c..589d1b33 100644 --- a/s2clientprotocol/ui_pb2.pyi +++ b/s2clientprotocol/ui_pb2.pyi @@ -1,18 +1,19 @@ from __future__ import annotations from enum import Enum +from typing import Iterable from google.protobuf.message import Message class ObservationUI(Message): - groups: list[ControlGroup] + groups: Iterable[ControlGroup] single: SinglePanel multi: MultiPanel cargo: CargoPanel production: ProductionPanel def __init__( self, - groups: list[ControlGroup] = ..., + groups: Iterable[ControlGroup] = ..., single: SinglePanel = ..., multi: MultiPanel = ..., cargo: CargoPanel = ..., @@ -62,28 +63,28 @@ class SinglePanel(Message): attack_upgrade_level: int armor_upgrade_level: int shield_upgrade_level: int - buffs: list[int] + buffs: Iterable[int] def __init__( self, unit: UnitInfo = ..., attack_upgrade_level: int = ..., armor_upgrade_level: int = ..., shield_upgrade_level: int = ..., - buffs: list[int] = ..., + buffs: Iterable[int] = ..., ) -> None: ... class MultiPanel(Message): - units: list[UnitInfo] - def __init__(self, units: list[UnitInfo] = ...) -> None: ... + units: Iterable[UnitInfo] + def __init__(self, units: Iterable[UnitInfo] = ...) -> None: ... class CargoPanel(Message): unit: UnitInfo - passengers: list[UnitInfo] + passengers: Iterable[UnitInfo] slots_available: int def __init__( self, unit: UnitInfo = ..., - passengers: list[UnitInfo] = ..., + passengers: Iterable[UnitInfo] = ..., slots_available: int = ..., ) -> None: ... @@ -94,13 +95,13 @@ class BuildItem(Message): class ProductionPanel(Message): unit: UnitInfo - build_queue: list[UnitInfo] - production_queue: list[BuildItem] + build_queue: Iterable[UnitInfo] + production_queue: Iterable[BuildItem] def __init__( self, unit: UnitInfo = ..., - build_queue: list[UnitInfo] = ..., - production_queue: list[BuildItem] = ..., + build_queue: Iterable[UnitInfo] = ..., + production_queue: Iterable[BuildItem] = ..., ) -> None: ... class ActionUI(Message): From 3295043d139311a6c2f89d694f6a796850677302 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Mon, 17 Nov 2025 17:54:14 +0100 Subject: [PATCH 19/34] Add typing to position.py --- sc2/position.py | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/sc2/position.py b/sc2/position.py index 7f7da4c4..71f49856 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -5,7 +5,7 @@ import math import random from collections.abc import Iterable -from typing import TYPE_CHECKING, SupportsFloat, SupportsIndex +from typing import TYPE_CHECKING, Self, SupportsFloat, SupportsIndex # pyre-fixme[21] from s2clientprotocol import common_pb2 as common_pb @@ -21,33 +21,35 @@ def _sign(num: SupportsFloat | SupportsIndex) -> float: return math.copysign(1, num) -class Pointlike(tuple): +class Pointlike(tuple[float, float]): @property - def position(self) -> Pointlike: + def position(self) -> Self: return self - def distance_to(self, target: Unit | Point2) -> float: + def distance_to(self, target: Unit | Pointlike) -> float: """Calculate a single distance from a point or unit to another point or unit :param target:""" p = target.position return math.hypot(self[0] - p[0], self[1] - p[1]) - def distance_to_point2(self, p: Point2 | tuple[float, float]) -> float: + def distance_to_point2(self, p: tuple[float, float] | tuple[float, float, float]) -> float: """Same as the function above, but should be a bit faster because of the dropped asserts and conversion. :param p:""" return math.hypot(self[0] - p[0], self[1] - p[1]) - def _distance_squared(self, p2: Point2) -> float: + def _distance_squared(self, p2: tuple[float, float] | tuple[float, float, float]) -> float: """Function used to not take the square root as the distances will stay proportionally the same. This is to speed up the sorting process. :param p2:""" return (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - def sort_by_distance(self, ps: Units | Iterable[Point2]) -> list[Point2]: + def sort_by_distance( + self, ps: Units | Iterable[tuple[float, float] | tuple[float, float, float]] + ) -> list[Unit | tuple[float, float] | tuple[float, float, float]]: """This returns the target points sorted as list. You should not pass a set or dict since those are not sortable. If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. @@ -55,7 +57,7 @@ def sort_by_distance(self, ps: Units | Iterable[Point2]) -> list[Point2]: :param ps:""" return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) - def closest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: + def closest(self, ps: Units | Iterable[Point2]) -> Unit | Pointlike: """This function assumes the 2d distance is meant :param ps:""" @@ -96,14 +98,14 @@ def distance_to_furthest(self, ps: Units | Iterable[Point2]) -> float: furthest_distance = distance return furthest_distance - def offset(self, p) -> Pointlike: + def offset(self, p: tuple[float, float]) -> Self: """ :param p: """ return self.__class__(a + b for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def unit_axes_towards(self, p) -> Pointlike: + def unit_axes_towards(self, p: tuple[float, float]) -> Self: """ :param p: @@ -130,7 +132,7 @@ def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0) ) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: tuple[float, float] | tuple[float, float, float]) -> bool: try: return all(abs(a - b) <= EPSILON for a, b in itertools.zip_longest(self, other, fillvalue=0)) except TypeError: @@ -156,10 +158,9 @@ def as_Point2D(self) -> common_pb.Point2D: return common_pb.Point2D(x=self.x, y=self.y) @property - # pyre-fixme[11] def as_PointI(self) -> common_pb.PointI: """Represents points on the minimap. Values must be between 0 and 64.""" - return common_pb.PointI(x=self.x, y=self.y) + return common_pb.PointI(x=int(self[0]), y=int(self[1])) @property def rounded(self) -> Point2: @@ -201,15 +202,16 @@ def round(self, decimals: int) -> Point2: def offset(self, p: Point2) -> Point2: return Point2((self[0] + p[0], self[1] + p[1])) - def random_on_distance(self, distance) -> Point2: + def random_on_distance(self, distance: float | tuple[float, float] | list[float]) -> Point2: if isinstance(distance, (tuple, list)): # interval - distance = distance[0] + random.random() * (distance[1] - distance[0]) - - assert distance > 0, "Distance is not greater than 0" + dist = distance[0] + random.random() * (distance[1] - distance[0]) + else: + dist = distance + assert dist > 0, "Distance is not greater than 0" angle = random.random() * 2 * math.pi dx, dy = math.cos(angle), math.sin(angle) - return Point2((self.x + dx * distance, self.y + dy * distance)) + return Point2((self.x + dx * dist, self.y + dy * dist)) def towards_with_random_angle( self, @@ -250,7 +252,7 @@ def circle_intersection(self, p: Point2, r: int | float) -> set[Point2]: return {intersect1, intersect2} @property - def neighbors4(self) -> set: + def neighbors4(self) -> set[Point2]: return { Point2((self.x - 1, self.y)), Point2((self.x + 1, self.y)), @@ -259,7 +261,7 @@ def neighbors4(self) -> set: } @property - def neighbors8(self) -> set: + def neighbors8(self) -> set[Point2]: return self.neighbors4 | { Point2((self.x - 1, self.y - 1)), Point2((self.x - 1, self.y + 1)), @@ -373,7 +375,7 @@ def height(self) -> float: return self[1] -class Rect(tuple): +class Rect(tuple[float, float, float, float]): @classmethod def from_proto(cls, data: common_pb.RectangleI) -> Rect: """ @@ -416,5 +418,5 @@ def size(self) -> Size: def center(self) -> Point2: return Point2((self.x + self.width / 2, self.y + self.height / 2)) - def offset(self, p) -> Rect: + def offset(self, p: tuple[float, float]) -> Rect: return self.__class__((self[0] + p[0], self[1] + p[1], self[2], self[3])) From 6116e35a69d968e675628a0b65e249576b80b905 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 01:57:45 +0100 Subject: [PATCH 20/34] Improve typing on example bots --- examples/arcade_bot.py | 2 +- examples/competitive/bot.py | 2 +- examples/distributed_workers.py | 2 +- examples/fastreload.py | 7 +++++-- examples/host_external_norestart.py | 4 ++-- examples/protoss/cannon_rush.py | 2 +- examples/protoss/find_adept_shades.py | 2 +- examples/protoss/threebase_voidray.py | 2 +- examples/protoss/warpgate_push.py | 9 +++++---- examples/simulate_fight_scenario.py | 4 ++-- examples/terran/cyclone_push.py | 10 +++++----- examples/terran/mass_reaper.py | 22 ++++++++++++---------- examples/terran/onebase_battlecruiser.py | 10 +++++----- examples/terran/proxy_rax.py | 2 +- examples/terran/ramp_wall.py | 6 +++--- examples/too_slow_bot.py | 2 +- examples/worker_rush.py | 2 +- examples/zerg/banes_banes_banes.py | 2 +- examples/zerg/expand_everywhere.py | 2 +- examples/zerg/hydralisk_push.py | 2 +- examples/zerg/onebase_broodlord.py | 2 +- examples/zerg/worker_split.py | 2 +- examples/zerg/zerg_rush.py | 6 +++--- 23 files changed, 56 insertions(+), 50 deletions(-) diff --git a/examples/arcade_bot.py b/examples/arcade_bot.py index 32bbf22c..811ee944 100644 --- a/examples/arcade_bot.py +++ b/examples/arcade_bot.py @@ -42,7 +42,7 @@ async def on_start(self): await self.chat_send("Edit this message for automatic chat commands.") self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # do marine micro vs zerglings for unit in self.units(UnitTypeId.MARINE): if self.enemy_units: diff --git a/examples/competitive/bot.py b/examples/competitive/bot.py index 5170635a..253337b6 100644 --- a/examples/competitive/bot.py +++ b/examples/competitive/bot.py @@ -7,7 +7,7 @@ async def on_start(self): print("Game started") # Do things here before the game starts - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Populate this function with whatever your bot should do! pass diff --git a/examples/distributed_workers.py b/examples/distributed_workers.py index 95d3d4af..9e7940e5 100644 --- a/examples/distributed_workers.py +++ b/examples/distributed_workers.py @@ -8,7 +8,7 @@ class TerranBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): await self.distribute_workers() await self.build_supply() await self.build_workers() diff --git a/examples/fastreload.py b/examples/fastreload.py index 3cacde4a..4fc5439a 100644 --- a/examples/fastreload.py +++ b/examples/fastreload.py @@ -4,11 +4,14 @@ from sc2 import maps from sc2.data import Difficulty, Race from sc2.main import _host_game_iter -from sc2.player import Bot, Computer +from sc2.player import AbstractPlayer, Bot, Computer def main(): - player_config = [Bot(Race.Zerg, zerg_rush.ZergRushBot()), Computer(Race.Terran, Difficulty.Medium)] + player_config: list[AbstractPlayer] = [ + Bot(Race.Zerg, zerg_rush.ZergRushBot()), + Computer(Race.Terran, Difficulty.Medium), + ] gen = _host_game_iter(maps.get("Abyssal Reef LE"), player_config, realtime=False) diff --git a/examples/host_external_norestart.py b/examples/host_external_norestart.py index eb2558a9..c5626ac0 100644 --- a/examples/host_external_norestart.py +++ b/examples/host_external_norestart.py @@ -1,13 +1,13 @@ -import sc2 from examples.zerg.zerg_rush import ZergRushBot from sc2 import maps from sc2.data import Race from sc2.main import _host_game_iter from sc2.player import Bot +from sc2.portconfig import Portconfig def main(): - portconfig = sc2.portconfig.Portconfig() + portconfig: Portconfig = Portconfig() print(portconfig.as_json) player_config = [Bot(Race.Zerg, ZergRushBot()), Bot(Race.Zerg, None)] diff --git a/examples/protoss/cannon_rush.py b/examples/protoss/cannon_rush.py index 2d287202..abe691ac 100644 --- a/examples/protoss/cannon_rush.py +++ b/examples/protoss/cannon_rush.py @@ -9,7 +9,7 @@ class CannonRushBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: await self.chat_send("(probe)(pylon)(cannon)(cannon)(gg)") diff --git a/examples/protoss/find_adept_shades.py b/examples/protoss/find_adept_shades.py index 8b136cfc..d10896d3 100644 --- a/examples/protoss/find_adept_shades.py +++ b/examples/protoss/find_adept_shades.py @@ -13,7 +13,7 @@ class FindAdeptShadesBot(BotAI): def __init__(self): self.shaded = False - self.shades_mapping = {} + self.shades_mapping: dict[int, int] = {} async def on_start(self): self.client.game_step = 2 diff --git a/examples/protoss/threebase_voidray.py b/examples/protoss/threebase_voidray.py index 2030a28f..314f6696 100644 --- a/examples/protoss/threebase_voidray.py +++ b/examples/protoss/threebase_voidray.py @@ -9,7 +9,7 @@ class ThreebaseVoidrayBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): target_base_count = 3 target_stargate_count = 3 diff --git a/examples/protoss/warpgate_push.py b/examples/protoss/warpgate_push.py index 7058ea43..f1b36079 100644 --- a/examples/protoss/warpgate_push.py +++ b/examples/protoss/warpgate_push.py @@ -9,6 +9,7 @@ from sc2.ids.upgrade_id import UpgradeId from sc2.main import run_game from sc2.player import Bot, Computer +from sc2.unit import Unit class WarpGateBot(BotAI): @@ -16,11 +17,11 @@ def __init__(self): # Initialize inherited class self.proxy_built = False - async def warp_new_units(self, proxy): + async def warp_new_units(self, proxy: Unit): for warpgate in self.structures(UnitTypeId.WARPGATE).ready: - abilities = await self.get_available_abilities(warpgate) + abilities = await self.get_available_abilities([warpgate]) # all the units have the same cooldown anyway so let's just look at ZEALOT - if AbilityId.WARPGATETRAIN_STALKER in abilities: + if AbilityId.WARPGATETRAIN_STALKER in abilities[0]: pos = proxy.position.to2.random_on_distance(4) placement = await self.find_placement(AbilityId.WARPGATETRAIN_STALKER, pos, placement_step=1) if placement is None: @@ -29,7 +30,7 @@ async def warp_new_units(self, proxy): return warpgate.warp_in(UnitTypeId.STALKER, placement) - async def on_step(self, iteration): + async def on_step(self, iteration: int): await self.distribute_workers() if not self.townhalls.ready: diff --git a/examples/simulate_fight_scenario.py b/examples/simulate_fight_scenario.py index b3d66511..2bee4390 100644 --- a/examples/simulate_fight_scenario.py +++ b/examples/simulate_fight_scenario.py @@ -15,7 +15,7 @@ class FightBot(BotAI): def __init__(self): super().__init__() - self.enemy_location: Point2 = None + self.enemy_location: Point2 | None = None self.fight_started = False async def on_start(self): @@ -23,7 +23,7 @@ async def on_start(self): await self.client.debug_show_map() await self.client.debug_control_enemy() - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Wait till control retrieved, destroy all starting units, recreate the world if iteration > 0 and self.enemy_units and not self.enemy_location: await self.reset_arena() diff --git a/examples/terran/cyclone_push.py b/examples/terran/cyclone_push.py index cf04c91a..f66a18bb 100644 --- a/examples/terran/cyclone_push.py +++ b/examples/terran/cyclone_push.py @@ -28,7 +28,7 @@ def select_target(self) -> Point2: # Pick a random mineral field on the map return self.mineral_field.random.position - async def on_step(self, iteration): + async def on_step(self, iteration: int): CCs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) # If no command center exists, attack-move with all workers and cyclones if not CCs: @@ -87,7 +87,7 @@ async def on_step(self, iteration): if self.gas_buildings.filter(lambda unit: unit.distance_to(vg) < 1): continue # Select a worker closest to the vespene geysir - worker: Unit = self.select_build_worker(vg) + worker: Unit | None = self.select_build_worker(vg) # Worker can be none in cases where all workers are dead # or 'select_build_worker' function only selects from workers which carry no minerals if worker is None: @@ -112,9 +112,9 @@ async def on_step(self, iteration): # Saturate gas for refinery in self.gas_buildings: if refinery.assigned_harvesters < refinery.ideal_harvesters: - worker: Units = self.workers.closer_than(10, refinery) - if worker: - worker.random.gather(refinery) + workers: Units = self.workers.closer_than(10, refinery) + if workers: + workers.random.gather(refinery) for scv in self.workers.idle: scv.gather(self.mineral_field.closest_to(cc)) diff --git a/examples/terran/mass_reaper.py b/examples/terran/mass_reaper.py index 01aba5dd..4ca3f1aa 100644 --- a/examples/terran/mass_reaper.py +++ b/examples/terran/mass_reaper.py @@ -25,7 +25,7 @@ def __init__(self): # Select distance calculation method 0, which is the pure python distance calculation without caching or indexing, using math.hypot(), for more info see bot_ai_internal.py _distances_override_functions() function self.distance_calculation_method = 3 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Benchmark and print duration time of the on_step method based on "self.distance_calculation_method" value # logger.info(self.time_formatted, self.supply_used, self.step_time[1]) """ @@ -45,7 +45,9 @@ async def on_step(self, iteration): # If workers were found if workers: worker: Unit = workers.furthest_to(workers.center) - location: Point2 = await self.find_placement(UnitTypeId.SUPPLYDEPOT, worker.position, placement_step=3) + location: Point2 | None = await self.find_placement( + UnitTypeId.SUPPLYDEPOT, worker.position, placement_step=3 + ) # If a placement location was found if location: # Order worker to build exactly on that location @@ -72,13 +74,13 @@ async def on_step(self, iteration): and self.can_afford(UnitTypeId.COMMANDCENTER) ): # get_next_expansion returns the position of the next possible expansion location where you can place a command center - location: Point2 = await self.get_next_expansion() + location: Point2 | None = await self.get_next_expansion() if location: # Now we "select" (or choose) the nearest worker to that found location - worker: Unit = self.select_build_worker(location) - if worker and self.can_afford(UnitTypeId.COMMANDCENTER): + worker2: Unit | None = self.select_build_worker(location) + if worker2 and self.can_afford(UnitTypeId.COMMANDCENTER): # The worker will be commanded to build the command center - worker.build(UnitTypeId.COMMANDCENTER, location) + worker2.build(UnitTypeId.COMMANDCENTER, location) # Build up to 4 barracks if we can afford them # Check if we have a supply depot (tech requirement) before trying to make barracks @@ -97,7 +99,7 @@ async def on_step(self, iteration): ): # need to check if townhalls.amount > 0 because placement is based on townhall location worker: Unit = workers.furthest_to(workers.center) # I chose placement_step 4 here so there will be gaps between barracks hopefully - location: Point2 = await self.find_placement( + location: Point2 | None = await self.find_placement( UnitTypeId.BARRACKS, self.townhalls.random.position, placement_step=4 ) if location: @@ -168,7 +170,7 @@ async def on_step(self, iteration): retreat_points: set[Point2] = {x for x in retreat_points if self.in_pathing_grid(x)} if retreat_points: closest_enemy: Unit = enemy_threats_close.closest_to(r) - retreat_point: Unit = closest_enemy.position.furthest(retreat_points) + retreat_point: Point2 = closest_enemy.position.furthest(retreat_points) r.move(retreat_point) continue # Continue for loop, dont execute any of the following @@ -259,13 +261,13 @@ async def on_step(self, iteration): # Stolen and modified from position.py @staticmethod - def neighbors4(position, distance=1) -> set[Point2]: + def neighbors4(position: Point2, distance: float = 1) -> set[Point2]: p = position d = distance return {Point2((p.x - d, p.y)), Point2((p.x + d, p.y)), Point2((p.x, p.y - d)), Point2((p.x, p.y + d))} # Stolen and modified from position.py - def neighbors8(self, position, distance=1) -> set[Point2]: + def neighbors8(self, position: Point2, distance: float = 1) -> set[Point2]: p = position d = distance return self.neighbors4(position, distance) | { diff --git a/examples/terran/onebase_battlecruiser.py b/examples/terran/onebase_battlecruiser.py index 1173af5e..47cd5f62 100644 --- a/examples/terran/onebase_battlecruiser.py +++ b/examples/terran/onebase_battlecruiser.py @@ -29,7 +29,7 @@ def select_target(self) -> tuple[Point2, bool]: return self.mineral_field.random.position, False - async def on_step(self, iteration): + async def on_step(self, iteration: int): ccs: Units = self.townhalls # If we no longer have townhalls, attack with all workers if not ccs: @@ -85,7 +85,7 @@ async def on_step(self, iteration): if self.gas_buildings.filter(lambda unit: unit.distance_to(vg) < 1): break - worker: Unit = self.select_build_worker(vg.position) + worker: Unit | None = self.select_build_worker(vg.position) if worker is None: break @@ -172,9 +172,9 @@ def starport_land_positions(sp_position: Point2) -> list[Point2]: # Saturate refineries for refinery in self.gas_buildings: if refinery.assigned_harvesters < refinery.ideal_harvesters: - worker: Units = self.workers.closer_than(10, refinery) - if worker: - worker.random.gather(refinery) + workers: Units = self.workers.closer_than(10, refinery) + if workers: + workers.random.gather(refinery) # Send workers back to mine if they are idle for scv in self.workers.idle: diff --git a/examples/terran/proxy_rax.py b/examples/terran/proxy_rax.py index 0ce8d789..5912e461 100644 --- a/examples/terran/proxy_rax.py +++ b/examples/terran/proxy_rax.py @@ -13,7 +13,7 @@ class ProxyRaxBot(BotAI): async def on_start(self): self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): # If we don't have a townhall anymore, send all units to attack ccs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) if not ccs: diff --git a/examples/terran/ramp_wall.py b/examples/terran/ramp_wall.py index 244bce99..0aa12fea 100644 --- a/examples/terran/ramp_wall.py +++ b/examples/terran/ramp_wall.py @@ -19,7 +19,7 @@ class RampWallBot(BotAI): def __init__(self): self.unit_command_uses_self_do = False - async def on_step(self, iteration): + async def on_step(self, iteration: int): ccs: Units = self.townhalls(UnitTypeId.COMMANDCENTER) if not ccs: return @@ -70,11 +70,11 @@ async def on_step(self, iteration): # Draw if two selected units are facing each other - green if this guy is facing the other, red if he is not self.draw_facing_units() - depot_placement_positions: frozenset[Point2] = self.main_base_ramp.corner_depots + depot_placement_positions: set[Point2] = self.main_base_ramp.corner_depots # Uncomment the following if you want to build 3 supply depots in the wall instead of a barracks in the middle + 2 depots in the corner # depot_placement_positions = self.main_base_ramp.corner_depots | {self.main_base_ramp.depot_in_middle} - barracks_placement_position: Point2 = self.main_base_ramp.barracks_correct_placement + barracks_placement_position: Point2 | None = self.main_base_ramp.barracks_correct_placement # If you prefer to have the barracks in the middle without room for addons, use the following instead # barracks_placement_position = self.main_base_ramp.barracks_in_middle diff --git a/examples/too_slow_bot.py b/examples/too_slow_bot.py index 28d32c0e..6e8b9baf 100644 --- a/examples/too_slow_bot.py +++ b/examples/too_slow_bot.py @@ -9,7 +9,7 @@ class SlowBot(ProxyRaxBot): - async def on_step(self, iteration): + async def on_step(self, iteration: int): await asyncio.sleep(random.random()) await super().on_step(iteration) diff --git a/examples/worker_rush.py b/examples/worker_rush.py index 686c7256..3f57e7d2 100644 --- a/examples/worker_rush.py +++ b/examples/worker_rush.py @@ -6,7 +6,7 @@ class WorkerRushBot(BotAI): - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: for worker in self.workers: worker.attack(self.enemy_start_locations[0]) diff --git a/examples/zerg/banes_banes_banes.py b/examples/zerg/banes_banes_banes.py index 85a00c70..c5216977 100644 --- a/examples/zerg/banes_banes_banes.py +++ b/examples/zerg/banes_banes_banes.py @@ -23,7 +23,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva lings: Units = self.units(UnitTypeId.ZERGLING) # Send all idle banes to enemy diff --git a/examples/zerg/expand_everywhere.py b/examples/zerg/expand_everywhere.py index 552ae0f9..87186ac2 100644 --- a/examples/zerg/expand_everywhere.py +++ b/examples/zerg/expand_everywhere.py @@ -16,7 +16,7 @@ async def on_start(self): self.client.game_step = 50 await self.client.debug_show_map() - async def on_step(self, iteration): + async def on_step(self, iteration: int): # Build overlords if about to be supply blocked if ( self.supply_left < 2 diff --git a/examples/zerg/hydralisk_push.py b/examples/zerg/hydralisk_push.py index 6e6d17e2..34f80003 100644 --- a/examples/zerg/hydralisk_push.py +++ b/examples/zerg/hydralisk_push.py @@ -19,7 +19,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva forces: Units = self.units.of_type({UnitTypeId.ZERGLING, UnitTypeId.HYDRALISK}) diff --git a/examples/zerg/onebase_broodlord.py b/examples/zerg/onebase_broodlord.py index 72d75bca..5db57a51 100644 --- a/examples/zerg/onebase_broodlord.py +++ b/examples/zerg/onebase_broodlord.py @@ -19,7 +19,7 @@ def select_target(self) -> Point2: return random.choice(self.enemy_structures).position return self.enemy_start_locations[0] - async def on_step(self, iteration): + async def on_step(self, iteration: int): larvae: Units = self.larva forces: Units = self.units.of_type({UnitTypeId.ZERGLING, UnitTypeId.CORRUPTOR, UnitTypeId.BROODLORD}) diff --git a/examples/zerg/worker_split.py b/examples/zerg/worker_split.py index 3edec5bb..689bdd90 100644 --- a/examples/zerg/worker_split.py +++ b/examples/zerg/worker_split.py @@ -30,7 +30,7 @@ async def on_before_start(self): async def on_start(self): """This function is run after the expansion locations and ramps are calculated.""" - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration % 10 == 0: await asyncio.sleep(3) # In realtime=False, this should print "8*x" and "x" if diff --git a/examples/zerg/zerg_rush.py b/examples/zerg/zerg_rush.py index 93139434..15d0df50 100644 --- a/examples/zerg/zerg_rush.py +++ b/examples/zerg/zerg_rush.py @@ -22,7 +22,7 @@ def __init__(self): async def on_start(self): self.client.game_step = 2 - async def on_step(self, iteration): + async def on_step(self, iteration: int): if iteration == 0: await self.chat_send("(glhf)") @@ -38,11 +38,11 @@ async def on_step(self, iteration): hatch: Unit = self.townhalls[0] # Pick a target location - target: Point2 = self.enemy_structures.not_flying.random_or(self.enemy_start_locations[0]).position + target_pos: Point2 = self.enemy_structures.not_flying.random_or(self.enemy_start_locations[0]).position # Give all zerglings an attack command for zergling in self.units(UnitTypeId.ZERGLING): - zergling.attack(target) + zergling.attack(target=target_pos) # Inject hatchery if queen has more than 25 energy for queen in self.units(UnitTypeId.QUEEN): From 81334b7fb6c0cfe1f552958a9b276abe8714ebc8 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:14:00 +0100 Subject: [PATCH 21/34] Improve typing in library files --- sc2/action.py | 1 - sc2/bot_ai.py | 2 +- sc2/bot_ai_internal.py | 16 +++-- sc2/cache.py | 6 +- sc2/client.py | 112 ++++++++++++++++++++++-------- sc2/controller.py | 11 ++- sc2/game_info.py | 6 +- sc2/game_state.py | 2 +- sc2/main.py | 104 +++++++++++++++++++-------- sc2/pixel_map.py | 3 +- sc2/player.py | 8 ++- sc2/portconfig.py | 6 +- sc2/position.py | 12 ++-- sc2/protocol.py | 6 +- sc2/proxy.py | 1 - sc2/renderer.py | 13 ++-- sc2/sc2process.py | 2 - test/generate_pickle_files_bot.py | 1 - test/test_pickled_data.py | 1 - 19 files changed, 217 insertions(+), 96 deletions(-) diff --git a/sc2/action.py b/sc2/action.py index d0e534ba..b43124ae 100644 --- a/sc2/action.py +++ b/sc2/action.py @@ -3,7 +3,6 @@ from itertools import groupby from typing import TYPE_CHECKING -# pyre-ignore[21] from s2clientprotocol import raw_pb2 as raw_pb from sc2.position import Point2 from sc2.unit import Unit diff --git a/sc2/bot_ai.py b/sc2/bot_ai.py index a4bcfb86..d98e72a0 100644 --- a/sc2/bot_ai.py +++ b/sc2/bot_ai.py @@ -1187,7 +1187,7 @@ async def chat_send(self, message: str, team_only: bool = False) -> None: assert isinstance(message, str), f"{message} is not a string" await self.client.chat_send(message, team_only) - def in_map_bounds(self, pos: Point2 | tuple | list) -> bool: + def in_map_bounds(self, pos: Point2 | tuple[float, float] | list[float]) -> bool: """Tests if a 2 dimensional point is within the map boundaries of the pixelmaps. :param pos:""" diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index 061ef70f..e7439212 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -14,7 +14,6 @@ import numpy as np from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb from sc2.cache import property_cache_once_per_frame from sc2.constants import ( @@ -41,7 +40,6 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - # pyre-ignore[21] from scipy.spatial.distance import cdist, pdist if TYPE_CHECKING: @@ -581,7 +579,7 @@ async def _do_actions(self, actions: list[UnitCommand], prevent_double: bool = T @final @staticmethod - def prevent_double_actions(action) -> bool: + def prevent_double_actions(action: UnitCommand) -> bool: """ :param action: """ @@ -608,7 +606,13 @@ def prevent_double_actions(action) -> bool: @final def _prepare_start( - self, client, player_id: int, game_info, game_data, realtime: bool = False, base_build: int = -1 + self, + client: Client, + player_id: int, + game_info: GameInfo, + game_data: GameData, + realtime: bool = False, + base_build: int = -1, ) -> None: """ Ran until game start to set game and player data. @@ -645,13 +649,13 @@ def _prepare_first_step(self) -> None: self._time_before_step: float = time.perf_counter() @final - def _prepare_step(self, state, proto_game_info) -> None: + def _prepare_step(self, state: GameState, proto_game_info: sc_pb.Response) -> None: """ :param state: :param proto_game_info: """ # Set attributes from new state before on_step.""" - self.state: GameState = state # See game_state.py + self.state = state # See game_state.py # update pathing grid, which unfortunately is in GameInfo instead of GameState self.game_info.pathing_grid = PixelMap(proto_game_info.game_info.start_raw.pathing_grid, in_bits=True) # Required for events, needs to be before self.units are initialized so the old units are stored diff --git a/sc2/cache.py b/sc2/cache.py index d3e9090d..b1682fc5 100644 --- a/sc2/cache.py +++ b/sc2/cache.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Hashable -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from sc2.bot_ai import BotAI @@ -9,7 +9,7 @@ T = TypeVar("T") -class CacheDict(dict): +class CacheDict(dict[Hashable, Any]): def retrieve_and_set(self, key: Hashable, func: Callable[[], T]) -> T: """Either return the value at a certain key, or set the return value of a function to that key, then return that value.""" @@ -29,7 +29,7 @@ class property_cache_once_per_frame(property): # noqa: N801 Copied and modified from https://tedboy.github.io/flask/_modules/werkzeug/utils.html#cached_property #""" - def __init__(self, func: Callable[[BotAI], T], name=None) -> None: + def __init__(self, func: Callable[[BotAI], T], name: str | None = None) -> None: self.__name__ = name or func.__name__ self.__frame__ = f"__frame__{self.__name__}" self.func = func diff --git a/sc2/client.py b/sc2/client.py index cde42374..6dfb8703 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -3,10 +3,11 @@ from collections.abc import Iterable from pathlib import Path +from typing import Any +from aiohttp import ClientWebSocketResponse from loguru import logger -# pyre-ignore[21] from s2clientprotocol import debug_pb2 as debug_pb from s2clientprotocol import query_pb2 as query_pb from s2clientprotocol import raw_pb2 as raw_pb @@ -18,6 +19,7 @@ from sc2.game_info import GameInfo from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId +from sc2.portconfig import Portconfig from sc2.position import Point2, Point3 from sc2.protocol import ConnectionAlreadyClosedError, Protocol, ProtocolError from sc2.renderer import Renderer @@ -27,7 +29,7 @@ class Client(Protocol): - def __init__(self, ws, save_replay_path: str = None) -> None: + def __init__(self, ws: ClientWebSocketResponse, save_replay_path: str | None = None) -> None: """ :param ws: """ @@ -52,7 +54,14 @@ def __init__(self, ws, save_replay_path: str = None) -> None: def in_game(self) -> bool: return self._status in {Status.in_game, Status.in_replay} - async def join_game(self, name=None, race=None, observed_player_id=None, portconfig=None, rgb_render_config=None): + async def join_game( + self, + name: str | None = None, + race: Race | None = None, + observed_player_id: int | None = None, + portconfig: Portconfig | None = None, + rgb_render_config: dict[str, Any] | None = None, + ): ifopts = sc_pb.InterfaceOptions( raw=True, score=True, @@ -121,14 +130,14 @@ async def leave(self) -> None: if is_resign: raise - async def save_replay(self, path) -> None: + async def save_replay(self, path: str) -> None: logger.debug("Requesting replay from server") result = await self._execute(save_replay=sc_pb.RequestSaveReplay()) with Path(path).open("wb") as f: f.write(result.save_replay.data) logger.info(f"Saved replay to {path}") - async def observation(self, game_loop: int = None): + async def observation(self, game_loop: int | None = None): if game_loop is not None: result = await self._execute(observation=sc_pb.RequestObservation(game_loop=game_loop)) else: @@ -152,13 +161,13 @@ async def observation(self, game_loop: int = None): return result - async def step(self, step_size: int = None): + async def step(self, step_size: int | None = None): """EXPERIMENTAL: Change self._client.game_step during the step function to increase or decrease steps per second""" step_size = step_size or self.game_step return await self._execute(step=sc_pb.RequestStep(count=step_size)) async def get_game_data(self) -> GameData: - result: sc_pb.ResponseData = await self._execute( + result: sc_pb.Response = await self._execute( data=sc_pb.RequestData(ability_id=True, unit_type_id=True, upgrade_id=True, buff_id=True, effect_id=True) ) return GameData(result.data) @@ -470,8 +479,8 @@ def debug_text_simple(self, text: str) -> None: def debug_text_screen( self, text: str, - pos: Point2 | Point3 | tuple | list, - color: tuple | list | Point3 = None, + pos: Point2 | Point3 | tuple[float, float] | list[float], + color: tuple[float, float] | list[float] | Point3 | None = None, size: int = 8, ) -> None: """ @@ -491,14 +500,18 @@ def debug_text_screen( def debug_text_2d( self, text: str, - pos: Point2 | Point3 | tuple | list, - color: tuple | list | Point3 = None, + pos: Point2 | Point3 | tuple[float, float] | list[float], + color: tuple[float, float] | list[float] | Point3 | None = None, size: int = 8, ): return self.debug_text_screen(text, pos, color, size) def debug_text_world( - self, text: str, pos: Unit | Point3, color: tuple | list | Point3 = None, size: int = 8 + self, + text: str, + pos: Unit | Point3, + color: tuple[float, float] | list[float] | Point3 | None = None, + size: int = 8, ) -> None: """ Draws a text at Point3 position in the game world. @@ -514,10 +527,21 @@ def debug_text_world( assert isinstance(pos, Point3) self._debug_texts.append(DrawItemWorldText(text=text, color=color, start_point=pos, font_size=size)) - def debug_text_3d(self, text: str, pos: Unit | Point3, color: tuple | list | Point3 = None, size: int = 8): + def debug_text_3d( + self, + text: str, + pos: Unit | Point3, + color: tuple[float, float] | list[float] | Point3 | None = None, + size: int = 8, + ): return self.debug_text_world(text, pos, color, size) - def debug_line_out(self, p0: Unit | Point3, p1: Unit | Point3, color: tuple | list | Point3 = None) -> None: + def debug_line_out( + self, + p0: Unit | Point3, + p1: Unit | Point3, + color: tuple[float, float] | list[float] | Point3 | None = None, + ) -> None: """ Draws a line from p0 to p1. @@ -537,7 +561,7 @@ def debug_box_out( self, p_min: Unit | Point3, p_max: Unit | Point3, - color: tuple | list | Point3 = None, + color: tuple[float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box with p_min and p_max as corners of the box. @@ -558,7 +582,7 @@ def debug_box2_out( self, pos: Unit | Point3, half_vertex_length: float = 0.25, - color: tuple | list | Point3 = None, + color: tuple[float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box center at a position 'pos', with box side lengths (vertices) of two times 'half_vertex_length'. @@ -574,7 +598,12 @@ def debug_box2_out( p1 = pos + Point3((half_vertex_length, half_vertex_length, half_vertex_length)) self._debug_boxes.append(DrawItemBox(start_point=p0, end_point=p1, color=color)) - def debug_sphere_out(self, p: Unit | Point3, r: float, color: tuple | list | Point3 = None) -> None: + def debug_sphere_out( + self, + p: Unit | Point3, + r: float, + color: tuple[float, float] | list[float] | Point3 | None = None, + ) -> None: """ Draws a sphere at point p with radius r. @@ -752,12 +781,12 @@ async def quick_load(self) -> None: class DrawItem: @staticmethod - def to_debug_color(color: tuple | Point3): + def to_debug_color(color: tuple[float, float] | list[float] | Point3 | None = None) -> debug_pb.Color: """Helper function for color conversion""" if color is None: return debug_pb.Color(r=255, g=255, b=255) # Need to check if not of type Point3 because Point3 inherits from tuple - if isinstance(color, (tuple, list)) and not isinstance(color, Point3) and len(color) == 3: + if isinstance(color, (tuple, list)) or isinstance(color, Point3) and len(color) == 3: return debug_pb.Color(r=color[0], g=color[1], b=color[2]) # In case color is of type Point3 r = getattr(color, "r", getattr(color, "x", 255)) @@ -773,11 +802,17 @@ def to_debug_color(color: tuple | Point3): class DrawItemScreenText(DrawItem): - def __init__(self, start_point: Point2 = None, color: Point3 = None, text: str = "", font_size: int = 8) -> None: - self._start_point: Point2 = start_point - self._color: Point3 = color - self._text: str = text - self._font_size: int = font_size + def __init__( + self, + start_point: Point2, + color: tuple[float, float] | list[float] | Point3 | None = None, + text: str = "", + font_size: int = 8, + ) -> None: + self._start_point = start_point + self._color = color + self._text = text + self._font_size = font_size def to_proto(self): return debug_pb.DebugText( @@ -793,7 +828,13 @@ def __hash__(self) -> int: class DrawItemWorldText(DrawItem): - def __init__(self, start_point: Point3 = None, color: Point3 = None, text: str = "", font_size: int = 8) -> None: + def __init__( + self, + start_point: Point3 = None, + color: tuple[float, float] | list[float] | Point3 | None = None, + text: str = "", + font_size: int = 8, + ) -> None: self._start_point: Point3 = start_point self._color: Point3 = color self._text: str = text @@ -813,7 +854,12 @@ def __hash__(self) -> int: class DrawItemLine(DrawItem): - def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + end_point: Point3 = None, + color: tuple[float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point self._color: Point3 = color @@ -829,7 +875,12 @@ def __hash__(self) -> int: class DrawItemBox(DrawItem): - def __init__(self, start_point: Point3 = None, end_point: Point3 = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + end_point: Point3 = None, + color: tuple[float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point self._color: Point3 = color @@ -846,7 +897,12 @@ def __hash__(self) -> int: class DrawItemSphere(DrawItem): - def __init__(self, start_point: Point3 = None, radius: float = None, color: Point3 = None) -> None: + def __init__( + self, + start_point: Point3 = None, + radius: float = None, + color: tuple[float, float] | list[float] | Point3 | None = None, + ) -> None: self._start_point: Point3 = start_point self._radius: float = radius self._color: Point3 = color diff --git a/sc2/controller.py b/sc2/controller.py index 2e480330..e068aa3f 100644 --- a/sc2/controller.py +++ b/sc2/controller.py @@ -1,17 +1,22 @@ +from __future__ import annotations + import platform from pathlib import Path +from typing import TYPE_CHECKING +from aiohttp import ClientWebSocketResponse from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.player import Computer from sc2.protocol import Protocol +if TYPE_CHECKING: + from sc2.sc2process import SC2Process + class Controller(Protocol): - def __init__(self, ws, process) -> None: + def __init__(self, ws: ClientWebSocketResponse, process: SC2Process) -> None: super().__init__(ws) self._process = process diff --git a/sc2/game_info.py b/sc2/game_info.py index 4e6f1243..aab025d5 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -123,10 +123,10 @@ def depot_in_middle(self) -> Point2 | None: raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") @cached_property - def corner_depots(self) -> frozenset[Point2]: + def corner_depots(self) -> set[Point2]: """Finds the 2 depot positions on the outside""" if not self.upper2_for_ramp_wall: - return frozenset() + return set() if len(self.upper2_for_ramp_wall) == 2: points = set(self.upper2_for_ramp_wall) p1 = points.pop().offset((self.x_offset, self.y_offset)) @@ -134,7 +134,7 @@ def corner_depots(self) -> frozenset[Point2]: center = p1.towards(p2, p1.distance_to_point2(p2) / 2) depot_position = self.depot_in_middle if depot_position is None: - return frozenset() + return set() # Offset from middle depot to corner depots is (2, 1) intersects = center.circle_intersection(depot_position, 5**0.5) return intersects diff --git a/sc2/game_state.py b/sc2/game_state.py index e4b7d17b..7ff4bb89 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -83,7 +83,7 @@ class Common: "larva_count", ] - def __init__(self, proto) -> None: + def __init__(self, proto: sc2api_pb2.PlayerCommon) -> None: self._proto = proto def __getattr__(self, attr) -> int: diff --git a/sc2/main.py b/sc2/main.py index fd86c6a7..88e82344 100644 --- a/sc2/main.py +++ b/sc2/main.py @@ -10,19 +10,21 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path +from typing import Any import mpyq import portpicker from aiohttp import ClientSession, ClientWebSocketResponse from loguru import logger -from s2clientprotocol import sc2api_pb2 as sc_pb +from s2clientprotocol import sc2api_pb2 as sc_pb from sc2.bot_ai import BotAI from sc2.client import Client from sc2.controller import Controller from sc2.data import CreateGameError, Result, Status from sc2.game_state import GameState from sc2.maps import Map +from sc2.observer_ai import ObserverAI from sc2.player import AbstractPlayer, Bot, BotProcess, Human from sc2.portconfig import Portconfig from sc2.protocol import ConnectionAlreadyClosedError, ProtocolError @@ -71,7 +73,7 @@ def needed_sc2_count(self) -> int: return sum(player.needs_sc2 for player in self.players) @property - def host_game_kwargs(self) -> dict: + def host_game_kwargs(self) -> dict[str, Any]: return { "map_settings": self.map_sc2, "players": self.players, @@ -203,7 +205,12 @@ async def run_bot_iteration(iteration: int): async def _play_game( - player: AbstractPlayer, client: Client, realtime, portconfig, game_time_limit=None, rgb_render_config=None + player: AbstractPlayer, + client: Client, + realtime: bool, + portconfig: Portconfig, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, ) -> Result: assert isinstance(realtime, bool), repr(realtime) @@ -328,16 +335,16 @@ async def _setup_host_game( async def _host_game( - map_settings, - players, + map_settings: Map, + players: list[AbstractPlayer], realtime: bool = False, - portconfig=None, - save_replay_as=None, - game_time_limit=None, - rgb_render_config=None, - random_seed=None, - sc2_version=None, - disable_fog=None, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, + random_seed: int | None = None, + sc2_version: str | None = None, + disable_fog: bool = False, ): assert players, "Can't create a game without players" @@ -410,19 +417,19 @@ def _host_game_iter(*args, **kwargs): async def _join_game( - players, - realtime, - portconfig, - save_replay_as=None, - game_time_limit=None, - sc2_version=None, + players: list[AbstractPlayer], + realtime: bool = False, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + sc2_version: str | None = None, ): async with SC2Process(fullscreen=players[1].fullscreen, sc2_version=sc2_version) as server: await server.ping() client = Client(server._ws) # Bot can decide if it wants to launch with 'raw_affects_selection=True' - if not isinstance(players[1], Human) and getattr(players[1].ai, "raw_affects_selection", None) is not None: + if isinstance(players[1], Bot) and getattr(players[1].ai, "raw_affects_selection", None) is not None: client.raw_affects_selection = players[1].ai.raw_affects_selection result = await _play_game(players[1], client, realtime, portconfig, game_time_limit) @@ -442,7 +449,9 @@ async def _setup_replay(server, replay_path, realtime, observed_id): return Client(server._ws) -async def _host_replay(replay_path, ai, realtime, _portconfig, base_build, data_version, observed_id): +async def _host_replay( + replay_path, ai: ObserverAI, realtime: bool, _portconfig: Portconfig, base_build, data_version, observed_id +): async with SC2Process(fullscreen=False, base_build=base_build, data_hash=data_version) as server: client = await _setup_replay(server, replay_path, realtime, observed_id) result = await _play_replay(client, ai, realtime) @@ -461,21 +470,47 @@ def get_replay_version(replay_path: str | Path) -> tuple[str, str]: # TODO Deprecate run_game function in favor of run_multiple_games -def run_game(map_settings, players, **kwargs) -> Result | list[Result | None]: +def run_game( + map_settings: Map, + players: list[AbstractPlayer], + realtime: bool, + portconfig: Portconfig | None = None, + save_replay_as: str | None = None, + game_time_limit: int | None = None, + rgb_render_config: dict[str, Any] | None = None, + random_seed: int | None = None, + sc2_version: str | None = None, + disable_fog: bool = False, +) -> Result | list[Result | None]: """ Returns a single Result enum if the game was against the built-in computer. Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot". """ if sum(isinstance(p, (Human, Bot)) for p in players) > 1: - host_only_args = ["save_replay_as", "rgb_render_config", "random_seed", "disable_fog"] - join_kwargs = {k: v for k, v in kwargs.items() if k not in host_only_args} - portconfig = Portconfig() async def run_host_and_join(): return await asyncio.gather( - _host_game(map_settings, players, **kwargs, portconfig=portconfig), - _join_game(players, **join_kwargs, portconfig=portconfig), + _host_game( + map_settings, + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + rgb_render_config=rgb_render_config, + random_seed=random_seed, + sc2_version=sc2_version, + disable_fog=disable_fog, + ), + _join_game( + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + sc2_version=sc2_version, + ), return_exceptions=True, ) @@ -483,12 +518,25 @@ async def run_host_and_join(): assert isinstance(result, list) assert all(isinstance(r, Result) for r in result) else: - result: Result = asyncio.run(_host_game(map_settings, players, **kwargs)) + result: Result = asyncio.run( + _host_game( + map_settings, + players, + realtime=realtime, + portconfig=portconfig, + save_replay_as=save_replay_as, + game_time_limit=game_time_limit, + rgb_render_config=rgb_render_config, + random_seed=random_seed, + sc2_version=sc2_version, + disable_fog=disable_fog, + ) + ) assert isinstance(result, Result) return result -def run_replay(ai, replay_path: Path | str, realtime: bool = False, observed_id: int = 0): +def run_replay(ai: ObserverAI, replay_path: Path | str, realtime: bool = False, observed_id: int = 0): portconfig = Portconfig() assert Path(replay_path).is_file(), f"Replay does not exist at the given path: {replay_path}" assert Path(replay_path).is_absolute(), ( diff --git a/sc2/pixel_map.py b/sc2/pixel_map.py index 6871a516..c6925d80 100644 --- a/sc2/pixel_map.py +++ b/sc2/pixel_map.py @@ -5,11 +5,12 @@ import numpy as np +from s2clientprotocol.common_pb2 import ImageData from sc2.position import Point2 class PixelMap: - def __init__(self, proto, in_bits: bool = False) -> None: + def __init__(self, proto: ImageData, in_bits: bool = False) -> None: """ :param proto: :param in_bits: diff --git a/sc2/player.py b/sc2/player.py index 7e7d5255..bd1410a5 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -101,13 +101,13 @@ def __init__( p_type: PlayerType, requested_race: Race, difficulty: Difficulty | None = None, - actual_race: Race = None, + actual_race: Race | None = None, name: str | None = None, ai_build: AIBuild | None = None, ) -> None: super().__init__(p_type, requested_race, difficulty=difficulty, name=name, ai_build=ai_build) self.id: int = player_id - self.actual_race: Race = actual_race + self.actual_race: Race | None = actual_race @classmethod def from_proto(cls, proto: sc2api_pb2.PlayerInfo) -> Player: @@ -171,7 +171,9 @@ def __repr__(self) -> str: return f"Bot {self.name}({self.race.name} from {self.launch_list})" return f"Bot({self.race.name} from {self.launch_list})" - def cmd_line(self, sc2port: int | str, matchport: int | str, hostaddress: str, realtime: bool = False) -> list[str]: + def cmd_line( + self, sc2port: int | str, matchport: int | str | None, hostaddress: str, realtime: bool = False + ) -> list[str]: """ :param sc2port: the port that the launched sc2 instance listens to diff --git a/sc2/portconfig.py b/sc2/portconfig.py index 2e646faf..9041b90f 100644 --- a/sc2/portconfig.py +++ b/sc2/portconfig.py @@ -25,9 +25,11 @@ class Portconfig: E.g. for 1v1, there will be only 1 guest. For 2v2 (coming soonTM), there would be 3 guests. """ - def __init__(self, guests: int = 1, server_ports=None, player_ports=None) -> None: + def __init__( + self, guests: int = 1, server_ports: list[int] | None = None, player_ports: list[int] | None = None + ) -> None: self.shared = None - self._picked_ports = [] + self._picked_ports: list[int] = [] if server_ports: self.server = server_ports else: diff --git a/sc2/position.py b/sc2/position.py index 71f49856..32e2fca3 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -5,7 +5,7 @@ import math import random from collections.abc import Iterable -from typing import TYPE_CHECKING, Self, SupportsFloat, SupportsIndex +from typing import TYPE_CHECKING, Self, SupportsFloat, SupportsIndex, overload # pyre-fixme[21] from s2clientprotocol import common_pb2 as common_pb @@ -77,7 +77,11 @@ def distance_to_closest(self, ps: Units | Iterable[Point2]) -> float: closest_distance = distance return closest_distance - def furthest(self, ps: Units | Iterable[Point2]) -> Unit | Pointlike: + @overload + def furthest(self, ps: Units) -> Unit: ... + @overload + def furthest(self, ps: Iterable[Point2]) -> Point2: ... + def furthest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: """This function assumes the 2d distance is meant :param ps: Units object, or iterable of Unit or Point2""" @@ -112,7 +116,7 @@ def unit_axes_towards(self, p: tuple[float, float]) -> Self: """ return self.__class__(_sign(b - a) for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Pointlike: + def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Self: """ :param p: @@ -199,7 +203,7 @@ def round(self, decimals: int) -> Point2: """Rounds each number in the tuple to the amount of given decimals.""" return Point2((round(self[0], decimals), round(self[1], decimals))) - def offset(self, p: Point2) -> Point2: + def offset(self, p: tuple[float, float]) -> Point2: return Point2((self[0] + p[0], self[1] + p[1])) def random_on_distance(self, distance: float | tuple[float, float] | list[float]) -> Point2: diff --git a/sc2/protocol.py b/sc2/protocol.py index 5577b08f..ce28608f 100644 --- a/sc2/protocol.py +++ b/sc2/protocol.py @@ -9,7 +9,6 @@ # pyre-fixme[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.data import Status @@ -34,7 +33,7 @@ def __init__(self, ws: ClientWebSocketResponse) -> None: # pyre-fixme[11] self._status: Status | None = None - async def __request(self, request): + async def __request(self, request: sc_pb.Request) -> sc_pb.Response: logger.debug(f"Sending request: {request!r}") try: await self._ws.send_bytes(request.SerializeToString()) @@ -65,7 +64,8 @@ async def __request(self, request): logger.debug("Response received") return response - async def _execute(self, **kwargs): + # TODO Add types using @overload with various request args and response types + async def _execute(self, **kwargs) -> sc_pb.Response: assert len(kwargs) == 1, "Only one request allowed by the API" response = await self.__request(sc_pb.Request(**kwargs)) diff --git a/sc2/proxy.py b/sc2/proxy.py index f2690322..340570cd 100644 --- a/sc2/proxy.py +++ b/sc2/proxy.py @@ -15,7 +15,6 @@ # pyre-fixme[21] from s2clientprotocol import sc2api_pb2 as sc_pb - from sc2.controller import Controller from sc2.data import Result, Status from sc2.player import BotProcess diff --git a/sc2/renderer.py b/sc2/renderer.py index 4d9f94ff..17e3599e 100644 --- a/sc2/renderer.py +++ b/sc2/renderer.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import datetime +from typing import TYPE_CHECKING -# pyre-ignore[21] from s2clientprotocol import score_pb2 as score_pb - +from s2clientprotocol.sc2api_pb2 import ResponseObservation from sc2.position import Point2 +if TYPE_CHECKING: + from sc2.client import Client + class Renderer: - def __init__(self, client, map_size, minimap_size) -> None: + def __init__(self, client: Client, map_size: tuple[float, float], minimap_size: tuple[float, float]) -> None: self._client = client self._window = None @@ -22,7 +27,7 @@ def __init__(self, client, map_size, minimap_size) -> None: self._text_score = None self._text_time = None - async def render(self, observation) -> None: + async def render(self, observation: ResponseObservation) -> None: render_data = observation.observation.render_data map_size = render_data.map.size diff --git a/sc2/sc2process.py b/sc2/sc2process.py index 3fa1777a..391d30b1 100644 --- a/sc2/sc2process.py +++ b/sc2/sc2process.py @@ -13,8 +13,6 @@ from typing import Any import aiohttp - -# pyre-ignore[21] import portpicker from aiohttp.client_ws import ClientWebSocketResponse from loguru import logger diff --git a/test/generate_pickle_files_bot.py b/test/generate_pickle_files_bot.py index ae8d46b8..3b4158ee 100644 --- a/test/generate_pickle_files_bot.py +++ b/test/generate_pickle_files_bot.py @@ -10,7 +10,6 @@ from loguru import logger -# pyre-ignore[21] from s2clientprotocol import sc2api_pb2 as sc_pb from sc2 import maps from sc2.bot_ai import BotAI diff --git a/test/test_pickled_data.py b/test/test_pickled_data.py index ca0e7767..8a6c5e69 100644 --- a/test/test_pickled_data.py +++ b/test/test_pickled_data.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import Any -# pyre-ignore[21] from google.protobuf.internal import api_implementation from hypothesis import given, settings from hypothesis import strategies as st From 5b9ae8f1c964ccd21156767e1900d3b97527911d Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:25:32 +0100 Subject: [PATCH 22/34] Add overloads for position --- sc2/position.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/sc2/position.py b/sc2/position.py index 32e2fca3..97735b22 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -4,7 +4,7 @@ import itertools import math import random -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Self, SupportsFloat, SupportsIndex, overload # pyre-fixme[21] @@ -47,9 +47,30 @@ def _distance_squared(self, p2: tuple[float, float] | tuple[float, float, float] :param p2:""" return (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 + @overload + def sort_by_distance(self, ps: Units) -> Sequence[Unit]: ... + @overload + def sort_by_distance(self, ps: Iterable[Point3]) -> Sequence[Point3]: ... + @overload + def sort_by_distance(self, ps: Iterable[Point2]) -> Sequence[Point2]: ... + @overload + def sort_by_distance(self, ps: Iterable[tuple[float, float]]) -> Sequence[tuple[float, float]]: ... + @overload + def sort_by_distance(self, ps: Iterable[tuple[float, float, float]]) -> Sequence[tuple[float, float, float]]: ... def sort_by_distance( - self, ps: Units | Iterable[tuple[float, float] | tuple[float, float, float]] - ) -> list[Unit | tuple[float, float] | tuple[float, float, float]]: + self, + ps: Units + | Iterable[Point3] + | Iterable[Point2] + | Iterable[tuple[float, float]] + | Iterable[tuple[float, float, float]], + ) -> ( + Sequence[Unit] + | Sequence[Point3] + | Sequence[Point2] + | Sequence[tuple[float, float]] + | Sequence[tuple[float, float, float]] + ): """This returns the target points sorted as list. You should not pass a set or dict since those are not sortable. If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. @@ -57,7 +78,11 @@ def sort_by_distance( :param ps:""" return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) - def closest(self, ps: Units | Iterable[Point2]) -> Unit | Pointlike: + @overload + def closest(self, ps: Units) -> Unit: ... + @overload + def closest(self, ps: Iterable[Point2]) -> Point2: ... + def closest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: """This function assumes the 2d distance is meant :param ps:""" From d52659069be0d2289112d9dd2b359dd72b8ffc46 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:35:16 +0100 Subject: [PATCH 23/34] Add overloads for protocol._execute --- s2clientprotocol/sc2api_pb2.pyi | 2 +- sc2/main.py | 4 +-- sc2/protocol.py | 47 ++++++++++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi index 76def499..67574e00 100644 --- a/s2clientprotocol/sc2api_pb2.pyi +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message diff --git a/sc2/main.py b/sc2/main.py index 88e82344..8d07314e 100644 --- a/sc2/main.py +++ b/sc2/main.py @@ -232,7 +232,7 @@ async def _play_game( return result -async def _play_replay(client, ai, realtime: bool = False, player_id: int = 0): +async def _play_replay(client: Client, ai, realtime: bool = False, player_id: int = 0): ai._initialize_variables() game_data = await client.get_game_data() @@ -773,7 +773,7 @@ async def a_run_multiple_games_nokill(matches: list[GameMatch]) -> list[dict[Abs # Start the matches results = [] - controllers = [] + controllers: list[Controller] = [] for m in matches: logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}") result = None diff --git a/sc2/protocol.py b/sc2/protocol.py index ce28608f..2722abe0 100644 --- a/sc2/protocol.py +++ b/sc2/protocol.py @@ -3,12 +3,14 @@ import asyncio import sys from contextlib import suppress +from typing import overload from aiohttp.client_ws import ClientWebSocketResponse from loguru import logger # pyre-fixme[21] from s2clientprotocol import sc2api_pb2 as sc_pb +from s2clientprotocol.query_pb2 import RequestQuery from sc2.data import Status @@ -64,7 +66,50 @@ async def __request(self, request: sc_pb.Request) -> sc_pb.Response: logger.debug("Response received") return response - # TODO Add types using @overload with various request args and response types + @overload + async def _execute(self, create_game: sc_pb.RequestCreateGame) -> sc_pb.Response: ... + @overload + async def _execute(self, join_game: sc_pb.RequestJoinGame) -> sc_pb.Response: ... + @overload + async def _execute(self, restart_game: sc_pb.RequestRestartGame) -> sc_pb.Response: ... + @overload + async def _execute(self, start_replay: sc_pb.RequestStartReplay) -> sc_pb.Response: ... + @overload + async def _execute(self, leave_game: sc_pb.RequestLeaveGame) -> sc_pb.Response: ... + @overload + async def _execute(self, quick_save: sc_pb.RequestQuickSave) -> sc_pb.Response: ... + @overload + async def _execute(self, quick_load: sc_pb.RequestQuickLoad) -> sc_pb.Response: ... + @overload + async def _execute(self, quit: sc_pb.RequestQuit) -> sc_pb.Response: ... + @overload + async def _execute(self, game_info: sc_pb.RequestGameInfo) -> sc_pb.Response: ... + @overload + async def _execute(self, action: sc_pb.RequestAction) -> sc_pb.Response: ... + @overload + async def _execute(self, observation: sc_pb.RequestObservation) -> sc_pb.Response: ... + @overload + async def _execute(self, obs_action: sc_pb.RequestObserverAction) -> sc_pb.Response: ... + @overload + async def _execute(self, step: sc_pb.RequestStep) -> sc_pb.Response: ... + @overload + async def _execute(self, data: sc_pb.RequestData) -> sc_pb.Response: ... + @overload + async def _execute(self, query: RequestQuery) -> sc_pb.Response: ... + @overload + async def _execute(self, save_replay: sc_pb.RequestSaveReplay) -> sc_pb.Response: ... + @overload + async def _execute(self, map_command: sc_pb.RequestMapCommand) -> sc_pb.Response: ... + @overload + async def _execute(self, replay_info: sc_pb.RequestReplayInfo) -> sc_pb.Response: ... + @overload + async def _execute(self, available_maps: sc_pb.RequestAvailableMaps) -> sc_pb.Response: ... + @overload + async def _execute(self, save_map: sc_pb.RequestSaveMap) -> sc_pb.Response: ... + @overload + async def _execute(self, ping: sc_pb.RequestPing) -> sc_pb.Response: ... + @overload + async def _execute(self, debug: sc_pb.RequestDebug) -> sc_pb.Response: ... async def _execute(self, **kwargs) -> sc_pb.Response: assert len(kwargs) == 1, "Only one request allowed by the API" From fddce9d1cb4049b99d6ba3e00eebcd160b07f1f8 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:48:49 +0100 Subject: [PATCH 24/34] Fix id_exists method --- sc2/game_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sc2/game_data.py b/sc2/game_data.py index 1a60963e..25f05433 100644 --- a/sc2/game_data.py +++ b/sc2/game_data.py @@ -78,7 +78,7 @@ class AbilityData: ability_ids: list[int] = [ability_id.value for ability_id in AbilityId][1:] # sorted list @classmethod - def id_exists(cls, ability_id: data_pb2.AbilityData | int) -> bool: + def id_exists(cls, ability_id: int) -> bool: assert isinstance(ability_id, int), f"Wrong type: {ability_id} is not int" if ability_id == 0: return False From c3db5652938da90f1700383c6841c548caf539c3 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:50:57 +0100 Subject: [PATCH 25/34] Fix ruff issues --- s2clientprotocol/data_pb2.pyi | 2 +- s2clientprotocol/debug_pb2.pyi | 2 +- s2clientprotocol/query_pb2.pyi | 3 ++- s2clientprotocol/raw_pb2.pyi | 2 +- s2clientprotocol/spatial_pb2.pyi | 2 +- s2clientprotocol/ui_pb2.pyi | 2 +- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index 00e286d8..8b839cf5 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -1,5 +1,5 @@ +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi index 40f88db4..edc956ee 100644 --- a/s2clientprotocol/debug_pb2.pyi +++ b/s2clientprotocol/debug_pb2.pyi @@ -1,5 +1,5 @@ +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index 19761aee..9f575c97 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,4 +1,5 @@ -from typing import Iterable + +from collections.abc import Iterable from google.protobuf.message import Message diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi index 840b1efd..34d89c6d 100644 --- a/s2clientprotocol/raw_pb2.pyi +++ b/s2clientprotocol/raw_pb2.pyi @@ -1,5 +1,5 @@ +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi index 3e5925e8..a1b72a29 100644 --- a/s2clientprotocol/spatial_pb2.pyi +++ b/s2clientprotocol/spatial_pb2.pyi @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi index 589d1b33..dbf39f3b 100644 --- a/s2clientprotocol/ui_pb2.pyi +++ b/s2clientprotocol/ui_pb2.pyi @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Iterable from enum import Enum -from typing import Iterable from google.protobuf.message import Message From 8160f3ff509d8f84665e1caaa26ae0470aa3c651 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:54:52 +0100 Subject: [PATCH 26/34] Fix autoformat issue --- s2clientprotocol/query_pb2.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index 9f575c97..746d86d9 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,4 +1,3 @@ - from collections.abc import Iterable from google.protobuf.message import Message From 2ad03c70a96c155539ceb971002854053f9ad061 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 02:59:47 +0100 Subject: [PATCH 27/34] Fix zerg cost --- sc2/game_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sc2/game_data.py b/sc2/game_data.py index 25f05433..4be84ee2 100644 --- a/sc2/game_data.py +++ b/sc2/game_data.py @@ -224,7 +224,6 @@ def unit_alias(self) -> UnitTypeId | None: return UnitTypeId(self._proto.unit_alias) @property - # pyre-ignore[11] def race(self) -> Race: return Race(self._proto.race) @@ -235,8 +234,7 @@ def cost(self) -> Cost: @property def cost_zerg_corrected(self) -> Cost: """This returns 25 for extractor and 200 for spawning pool instead of 75 and 250 respectively""" - # pyre-ignore[16] - if self.race.value == Race.Zerg and Attribute.Structure in self.attributes: + if self.race == Race.Zerg and Attribute.Structure.value in self._proto.attributes: return Cost(self._proto.mineral_cost - 50, self._proto.vespene_cost, self._proto.build_time) return self.cost From c78621e93bab5fbe85372817ea3be9aedf770d20 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 03:03:22 +0100 Subject: [PATCH 28/34] Revert using typing.Self --- sc2/position.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sc2/position.py b/sc2/position.py index 97735b22..53032192 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -5,7 +5,7 @@ import math import random from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Self, SupportsFloat, SupportsIndex, overload +from typing import TYPE_CHECKING, SupportsFloat, SupportsIndex, overload # pyre-fixme[21] from s2clientprotocol import common_pb2 as common_pb @@ -23,7 +23,7 @@ def _sign(num: SupportsFloat | SupportsIndex) -> float: class Pointlike(tuple[float, float]): @property - def position(self) -> Self: + def position(self) -> Pointlike: return self def distance_to(self, target: Unit | Pointlike) -> float: @@ -127,21 +127,21 @@ def distance_to_furthest(self, ps: Units | Iterable[Point2]) -> float: furthest_distance = distance return furthest_distance - def offset(self, p: tuple[float, float]) -> Self: + def offset(self, p: tuple[float, float]) -> Pointlike: """ :param p: """ return self.__class__(a + b for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def unit_axes_towards(self, p: tuple[float, float]) -> Self: + def unit_axes_towards(self, p: tuple[float, float]) -> Pointlike: """ :param p: """ return self.__class__(_sign(b - a) for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Self: + def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Pointlike: """ :param p: From 2042b1b7407b4ca61b5d7297f4f546771a3ee234 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 03:08:50 +0100 Subject: [PATCH 29/34] Fix annotation by importing from future --- sc2/score.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sc2/score.py b/sc2/score.py index 18df8f38..aba9c8ff 100644 --- a/sc2/score.py +++ b/sc2/score.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from s2clientprotocol import score_pb2 From dbf8e850b7c532e47be00f4eeefe26c5870120d1 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 03:14:12 +0100 Subject: [PATCH 30/34] Fix type hints for color --- sc2/client.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/sc2/client.py b/sc2/client.py index 6dfb8703..8971b86f 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -480,7 +480,7 @@ def debug_text_screen( self, text: str, pos: Point2 | Point3 | tuple[float, float] | list[float], - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ) -> None: """ @@ -501,7 +501,7 @@ def debug_text_2d( self, text: str, pos: Point2 | Point3 | tuple[float, float] | list[float], - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ): return self.debug_text_screen(text, pos, color, size) @@ -510,7 +510,7 @@ def debug_text_world( self, text: str, pos: Unit | Point3, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ) -> None: """ @@ -531,7 +531,7 @@ def debug_text_3d( self, text: str, pos: Unit | Point3, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, size: int = 8, ): return self.debug_text_world(text, pos, color, size) @@ -540,7 +540,7 @@ def debug_line_out( self, p0: Unit | Point3, p1: Unit | Point3, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a line from p0 to p1. @@ -561,7 +561,7 @@ def debug_box_out( self, p_min: Unit | Point3, p_max: Unit | Point3, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box with p_min and p_max as corners of the box. @@ -582,7 +582,7 @@ def debug_box2_out( self, pos: Unit | Point3, half_vertex_length: float = 0.25, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a box center at a position 'pos', with box side lengths (vertices) of two times 'half_vertex_length'. @@ -602,7 +602,7 @@ def debug_sphere_out( self, p: Unit | Point3, r: float, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: """ Draws a sphere at point p with radius r. @@ -781,7 +781,7 @@ async def quick_load(self) -> None: class DrawItem: @staticmethod - def to_debug_color(color: tuple[float, float] | list[float] | Point3 | None = None) -> debug_pb.Color: + def to_debug_color(color: tuple[float, float, float] | list[float] | Point3 | None = None) -> debug_pb.Color: """Helper function for color conversion""" if color is None: return debug_pb.Color(r=255, g=255, b=255) @@ -805,7 +805,7 @@ class DrawItemScreenText(DrawItem): def __init__( self, start_point: Point2, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, text: str = "", font_size: int = 8, ) -> None: @@ -831,7 +831,7 @@ class DrawItemWorldText(DrawItem): def __init__( self, start_point: Point3 = None, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, text: str = "", font_size: int = 8, ) -> None: @@ -858,7 +858,7 @@ def __init__( self, start_point: Point3 = None, end_point: Point3 = None, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point @@ -879,7 +879,7 @@ def __init__( self, start_point: Point3 = None, end_point: Point3 = None, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: self._start_point: Point3 = start_point self._end_point: Point3 = end_point @@ -901,7 +901,7 @@ def __init__( self, start_point: Point3 = None, radius: float = None, - color: tuple[float, float] | list[float] | Point3 | None = None, + color: tuple[float, float, float] | list[float] | Point3 | None = None, ) -> None: self._start_point: Point3 = start_point self._radius: float = radius From 26b3b844a7e930d7184cc5302e94878c322c6c0a Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Tue, 18 Nov 2025 16:55:15 +0100 Subject: [PATCH 31/34] Replace IntEnum with Enum and clean up --- sc2/data.pyi | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sc2/data.pyi b/sc2/data.pyi index 57ad0c1a..4783460e 100644 --- a/sc2/data.pyi +++ b/sc2/data.pyi @@ -10,12 +10,11 @@ and mypy to understand the structure and members of these enums. from __future__ import annotations -from enum import Enum, IntEnum +from enum import Enum from sc2.ids.ability_id import AbilityId from sc2.ids.unit_typeid import UnitTypeId -# Enums created from sc2api_pb2 class CreateGameError(Enum): MissingMap = 1 InvalidMapPath = 2 @@ -26,7 +25,7 @@ class CreateGameError(Enum): InvalidPlayerSetup = 7 MultiplayerUnsupported = 8 -class PlayerType(IntEnum): +class PlayerType(Enum): Participant = 1 Computer = 2 Observer = 3 @@ -94,8 +93,7 @@ class ChatChannel(Enum): Broadcast = 1 Team = 2 -# Enums created from common_pb2 -class Race(IntEnum): +class Race(Enum): """StarCraft II race enum. Members: @@ -132,7 +130,6 @@ class CloakState(Enum): NotCloaked = 4 CloakedAllied = 5 -# Enums created from data_pb2 class Attribute(Enum): Light = 1 Armored = 2 @@ -160,7 +157,6 @@ class Target(Enum): PointOrUnit = 3 PointOrNone = 4 -# Enums created from error_pb2 class ActionResult(Enum): """Action result codes from game engine. From 1361944a291a9e3b5b57dd3c26b56f0e3e8e71be Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Thu, 20 Nov 2025 15:07:32 +0100 Subject: [PATCH 32/34] Simplify typing for position.py --- sc2/bot_ai_internal.py | 10 +- sc2/position.py | 203 +++++++++++++++++++---------------------- sc2/unit.py | 10 +- 3 files changed, 102 insertions(+), 121 deletions(-) diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index e7439212..d2bde3f7 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -33,7 +33,7 @@ from sc2.ids.unit_typeid import UnitTypeId from sc2.ids.upgrade_id import UpgradeId from sc2.pixel_map import PixelMap -from sc2.position import Point2 +from sc2.position import Point2, _PointLike from sc2.unit import Unit from sc2.unit_command import UnitCommand from sc2.units import Units @@ -1016,16 +1016,16 @@ def convert_tuple_to_numpy_array(pos: tuple[float, float]) -> np.ndarray: @final @staticmethod def distance_math_hypot( - p1: tuple[float, float] | Point2, - p2: tuple[float, float] | Point2, + p1: _PointLike, + p2: _PointLike, ) -> float: return math.hypot(p1[0] - p2[0], p1[1] - p2[1]) @final @staticmethod def distance_math_hypot_squared( - p1: tuple[float, float] | Point2, - p2: tuple[float, float] | Point2, + p1: _PointLike, + p2: _PointLike, ) -> float: return pow(p1[0] - p2[0], 2) + pow(p1[1] - p2[1], 2) diff --git a/sc2/position.py b/sc2/position.py index 53032192..fdbefb2e 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -1,18 +1,29 @@ -# pyre-ignore-all-errors[6, 14, 15, 58] from __future__ import annotations import itertools import math import random -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, SupportsFloat, SupportsIndex, overload +from collections.abc import Iterable +from typing import ( + Any, + Protocol, + SupportsFloat, + SupportsIndex, + TypeVar, + Union, +) -# pyre-fixme[21] from s2clientprotocol import common_pb2 as common_pb -if TYPE_CHECKING: - from sc2.unit import Unit - from sc2.units import Units + +class HasPosition2D(Protocol): + @property + def position(self) -> Point2: ... + + +_PointLike = Union[tuple[float, float], tuple[float, float], tuple[float, ...]] +_PosLike = Union[HasPosition2D, _PointLike] +_TPosLike = TypeVar("_TPosLike", bound=_PosLike) EPSILON: float = 10**-8 @@ -21,147 +32,118 @@ def _sign(num: SupportsFloat | SupportsIndex) -> float: return math.copysign(1, num) -class Pointlike(tuple[float, float]): +class Pointlike(tuple[float, ...]): + T = TypeVar("T", bound="Pointlike") + @property - def position(self) -> Pointlike: + def position(self: T) -> T: return self - def distance_to(self, target: Unit | Pointlike) -> float: + def distance_to(self, target: _PosLike) -> float: """Calculate a single distance from a point or unit to another point or unit :param target:""" - p = target.position + p: Point2 | Point3 = target if hasattr(target, "position") else target # pyright: ignore[reportAssignmentType] return math.hypot(self[0] - p[0], self[1] - p[1]) - def distance_to_point2(self, p: tuple[float, float] | tuple[float, float, float]) -> float: + def distance_to_point2(self, p: _PointLike) -> float: """Same as the function above, but should be a bit faster because of the dropped asserts and conversion. :param p:""" return math.hypot(self[0] - p[0], self[1] - p[1]) - def _distance_squared(self, p2: tuple[float, float] | tuple[float, float, float]) -> float: + def _distance_squared(self, p2: _PointLike) -> float: """Function used to not take the square root as the distances will stay proportionally the same. This is to speed up the sorting process. :param p2:""" return (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - @overload - def sort_by_distance(self, ps: Units) -> Sequence[Unit]: ... - @overload - def sort_by_distance(self, ps: Iterable[Point3]) -> Sequence[Point3]: ... - @overload - def sort_by_distance(self, ps: Iterable[Point2]) -> Sequence[Point2]: ... - @overload - def sort_by_distance(self, ps: Iterable[tuple[float, float]]) -> Sequence[tuple[float, float]]: ... - @overload - def sort_by_distance(self, ps: Iterable[tuple[float, float, float]]) -> Sequence[tuple[float, float, float]]: ... - def sort_by_distance( - self, - ps: Units - | Iterable[Point3] - | Iterable[Point2] - | Iterable[tuple[float, float]] - | Iterable[tuple[float, float, float]], - ) -> ( - Sequence[Unit] - | Sequence[Point3] - | Sequence[Point2] - | Sequence[tuple[float, float]] - | Sequence[tuple[float, float, float]] - ): + def sort_by_distance(self, ps: Iterable[_TPosLike]) -> list[_TPosLike]: """This returns the target points sorted as list. You should not pass a set or dict since those are not sortable. If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. :param ps:""" - return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) + return sorted(ps, key=lambda p: self.distance_to_point2(p.position if hasattr(p, "position") else p)) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportArgumentType] - @overload - def closest(self, ps: Units) -> Unit: ... - @overload - def closest(self, ps: Iterable[Point2]) -> Point2: ... - def closest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: + def closest(self, ps: Iterable[_TPosLike]) -> _TPosLike: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" - return min(ps, key=lambda p: self.distance_to(p)) + return min(ps, key=lambda p: self.distance_to_point2(p if hasattr(p, "position") else p)) # pyright: ignore[reportArgumentType] - def distance_to_closest(self, ps: Units | Iterable[Point2]) -> float: + def distance_to_closest(self, ps: Iterable[_TPosLike]) -> float: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" closest_distance = math.inf - for p2 in ps: - p2 = p2.position - distance = self.distance_to(p2) + for p in ps: + p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] + distance = self.distance_to_point2(p2) if distance <= closest_distance: closest_distance = distance return closest_distance - @overload - def furthest(self, ps: Units) -> Unit: ... - @overload - def furthest(self, ps: Iterable[Point2]) -> Point2: ... - def furthest(self, ps: Units | Iterable[Point2]) -> Unit | Point2: + def furthest(self, ps: Iterable[_TPosLike]) -> _TPosLike: """This function assumes the 2d distance is meant :param ps: Units object, or iterable of Unit or Point2""" assert ps, "ps is empty" - return max(ps, key=lambda p: self.distance_to(p)) + return max(ps, key=lambda p: self.distance_to_point2(p if hasattr(p, "position") else p)) # pyright: ignore[reportArgumentType] - def distance_to_furthest(self, ps: Units | Iterable[Point2]) -> float: + def distance_to_furthest(self, ps: Iterable[_PosLike]) -> float: """This function assumes the 2d distance is meant :param ps:""" assert ps, "ps is empty" furthest_distance = -math.inf - for p2 in ps: - p2 = p2.position - distance = self.distance_to(p2) + for p in ps: + p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] + distance = self.distance_to_point2(p2) if distance >= furthest_distance: furthest_distance = distance return furthest_distance - def offset(self, p: tuple[float, float]) -> Pointlike: + def offset(self: T, p: _PointLike) -> T: """ :param p: """ return self.__class__(a + b for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def unit_axes_towards(self, p: tuple[float, float]) -> Pointlike: + def unit_axes_towards(self: T, p: _PointLike) -> T: """ :param p: """ return self.__class__(_sign(b - a) for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) - def towards(self, p: Unit | Pointlike, distance: int | float = 1, limit: bool = False) -> Pointlike: + def towards(self: T, p: _PosLike, distance: float = 1, limit: bool = False) -> T: """ :param p: :param distance: :param limit: """ - p = p.position + p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] # assert self != p, f"self is {self}, p is {p}" # TODO test and fix this if statement - if self == p: + if self == p2: return self # end of test - d = self.distance_to(p) + d = self.distance_to_point2(p2) if limit: distance = min(d, distance) return self.__class__( - a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0) + a + (b - a) / d * distance for a, b in itertools.zip_longest(self, p2[: len(self)], fillvalue=0) ) - def __eq__(self, other: tuple[float, float] | tuple[float, float, float]) -> bool: + def __eq__(self, other: Any) -> bool: try: return all(abs(a - b) <= EPSILON for a, b in itertools.zip_longest(self, other, fillvalue=0)) except TypeError: @@ -172,6 +154,8 @@ def __hash__(self) -> int: class Point2(Pointlike): + T = TypeVar("T", bound="Point2") + @classmethod def from_proto( cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI | Point2 | Point3 @@ -182,7 +166,6 @@ def from_proto( return cls((data.x, data.y)) @property - # pyre-fixme[11] def as_Point2D(self) -> common_pb.Point2D: return common_pb.Point2D(x=self.x, y=self.y) @@ -201,12 +184,12 @@ def length(self) -> float: return math.hypot(self[0], self[1]) @property - def normalized(self) -> Point2: + def normalized(self: Point2 | Point3) -> Point2: """This property exists in case Point2 is used as a vector.""" length = self.length # Cannot normalize if length is zero assert length - return self.__class__((self[0] / length, self[1] / length)) + return Point2((self[0] / length, self[1] / length)) @property def x(self) -> float: @@ -228,8 +211,8 @@ def round(self, decimals: int) -> Point2: """Rounds each number in the tuple to the amount of given decimals.""" return Point2((round(self[0], decimals), round(self[1], decimals))) - def offset(self, p: tuple[float, float]) -> Point2: - return Point2((self[0] + p[0], self[1] + p[1])) + def offset(self: T, p: _PointLike) -> T: + return self.__class__((self[0] + p[0], self[1] + p[1])) def random_on_distance(self, distance: float | tuple[float, float] | list[float]) -> Point2: if isinstance(distance, (tuple, list)): # interval @@ -253,7 +236,7 @@ def towards_with_random_angle( angle = (angle - max_difference) + max_difference * 2 * random.random() return Point2((self.x + math.cos(angle) * distance, self.y + math.sin(angle) * distance)) - def circle_intersection(self, p: Point2, r: int | float) -> set[Point2]: + def circle_intersection(self, p: Point2, r: float) -> set[Point2]: """self is point1, p is point2, r is the radius for circles originating in both points Used in ramp finding @@ -281,68 +264,66 @@ def circle_intersection(self, p: Point2, r: int | float) -> set[Point2]: return {intersect1, intersect2} @property - def neighbors4(self) -> set[Point2]: + def neighbors4(self: T) -> set[T]: return { - Point2((self.x - 1, self.y)), - Point2((self.x + 1, self.y)), - Point2((self.x, self.y - 1)), - Point2((self.x, self.y + 1)), + self.__class__((self[0] - 1, self[1])), + self.__class__((self[0] + 1, self[1])), + self.__class__((self[0], self[1] - 1)), + self.__class__((self[0], self[1] + 1)), } @property - def neighbors8(self) -> set[Point2]: + def neighbors8(self: T) -> set[T]: return self.neighbors4 | { - Point2((self.x - 1, self.y - 1)), - Point2((self.x - 1, self.y + 1)), - Point2((self.x + 1, self.y - 1)), - Point2((self.x + 1, self.y + 1)), + self.__class__((self[0] - 1, self[1] - 1)), + self.__class__((self[0] - 1, self[1] + 1)), + self.__class__((self[0] + 1, self[1] - 1)), + self.__class__((self[0] + 1, self[1] + 1)), } - def negative_offset(self, other: Point2) -> Point2: + def negative_offset(self: T, other: Point2) -> T: return self.__class__((self[0] - other[0], self[1] - other[1])) - def __add__(self, other: Point2) -> Point2: + def __add__(self, other: Point2) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return self.offset(other) def __sub__(self, other: Point2) -> Point2: return self.negative_offset(other) - def __neg__(self) -> Point2: + def __neg__(self: T) -> T: return self.__class__(-a for a in self) def __abs__(self) -> float: - return math.hypot(self.x, self.y) + return math.hypot(self[0], self[1]) def __bool__(self) -> bool: - return self.x != 0 or self.y != 0 + return self[0] != 0 or self[1] != 0 - def __mul__(self, other: int | float | Point2) -> Point2: - try: - # pyre-ignore[16] - return self.__class__((self.x * other.x, self.y * other.y)) - except AttributeError: - return self.__class__((self.x * other, self.y * other)) + def __mul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + if isinstance(other, (int, float)): + return Point2((self[0] * other, self[1] * other)) + return Point2((self[0] * other[0], self[1] * other[1])) - def __rmul__(self, other: int | float | Point2) -> Point2: + def __rmul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return self.__mul__(other) - def __truediv__(self, other: int | float | Point2) -> Point2: - if isinstance(other, self.__class__): - return self.__class__((self.x / other.x, self.y / other.y)) - return self.__class__((self.x / other, self.y / other)) + def __truediv__(self, other: float | Point2) -> Point2: + if isinstance(other, (int, float)): + return self.__class__((self[0] / other, self[1] / other)) + return self.__class__((self[0] / other[0], self[1] / other[1])) def is_same_as(self, other: Point2, dist: float = 0.001) -> bool: return self.distance_to_point2(other) <= dist def direction_vector(self, other: Point2) -> Point2: """Converts a vector to a direction that can face vertically, horizontally or diagonal or be zero, e.g. (0, 0), (1, -1), (1, 0)""" - return self.__class__((_sign(other.x - self.x), _sign(other.y - self.y))) + return self.__class__((_sign(other[0] - self[0]), _sign(other[1] - self[1]))) def manhattan_distance(self, other: Point2) -> float: """ :param other: """ - return abs(other.x - self.x) + abs(other.y - self.y) + return abs(other[0] - self[0]) + abs(other[1] - self[1]) @staticmethod def center(points: list[Point2]) -> Point2: @@ -357,14 +338,13 @@ def center(points: list[Point2]) -> Point2: class Point3(Point2): @classmethod - def from_proto(cls, data: common_pb.Point | Point3) -> Point3: + def from_proto(cls, data: common_pb.Point | Point3) -> Point3: # pyright: ignore[reportIncompatibleMethodOverride] """ :param data: """ return cls((data.x, data.y, data.z)) @property - # pyre-fixme[11] def as_Point(self) -> common_pb.Point: return common_pb.Point(x=self.x, y=self.y, z=self.z) @@ -381,15 +361,16 @@ def to3(self) -> Point3: return Point3(self) def __add__(self, other: Point2 | Point3) -> Point3: - if not isinstance(other, Point3) and isinstance(other, Point2): - return Point3((self.x + other.x, self.y + other.y, self.z)) - # pyre-ignore[16] - return Point3((self.x + other.x, self.y + other.y, self.z + other.z)) + if not isinstance(other, Point3): + return Point3((self[0] + other[0], self[1] + other[1], self[2])) + return Point3((self[0] + other[0], self[1] + other[1], self[2] + other[2])) class Size(Point2): @classmethod - def from_proto(cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI) -> Size: + def from_proto( + cls, data: common_pb.Point | common_pb.Point2D | common_pb.Size2DI | common_pb.PointI | Point2 + ) -> Size: """ :param data: """ @@ -404,9 +385,9 @@ def height(self) -> float: return self[1] -class Rect(tuple[float, float, float, float]): +class Rect(Point2): @classmethod - def from_proto(cls, data: common_pb.RectangleI) -> Rect: + def from_proto(cls, data: common_pb.RectangleI) -> Rect: # pyright: ignore[reportIncompatibleMethodOverride] """ :param data: """ @@ -444,8 +425,8 @@ def size(self) -> Size: return Size((self[2], self[3])) @property - def center(self) -> Point2: + def center(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] return Point2((self.x + self.width / 2, self.y + self.height / 2)) - def offset(self, p: tuple[float, float]) -> Rect: + def offset(self, p: _PointLike) -> Rect: return self.__class__((self[0] + p[0], self[1] + p[1], self[2], self[3])) diff --git a/sc2/unit.py b/sc2/unit.py index 9da68762..236116f9 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -58,7 +58,7 @@ from sc2.ids.buff_id import BuffId from sc2.ids.unit_typeid import UnitTypeId from sc2.ids.upgrade_id import UpgradeId -from sc2.position import Point2, Point3 +from sc2.position import HasPosition2D, Point2, Point3, _PointLike from sc2.unit_command import UnitCommand if TYPE_CHECKING: @@ -103,7 +103,7 @@ def __repr__(self) -> str: return f"UnitOrder({self.ability}, {self.target}, {self.progress})" -class Unit: +class Unit(HasPosition2D): class_cache = CacheDict() def __init__( @@ -529,7 +529,7 @@ def position_tuple(self) -> tuple[float, float]: return self._proto.pos.x, self._proto.pos.y @cached_property - def position(self) -> Point2: + def position(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] """Returns the 2d position of the unit.""" return Point2.from_proto(self._proto.pos) @@ -538,7 +538,7 @@ def position3d(self) -> Point3: """Returns the 3d position of the unit.""" return Point3.from_proto(self._proto.pos) - def distance_to(self, p: Unit | Point2) -> float: + def distance_to(self, p: Unit | _PointLike) -> float: """Using the 2d distance between self and p. To calculate the 3d distance, use unit.position3d.distance_to(p) @@ -548,7 +548,7 @@ def distance_to(self, p: Unit | Point2) -> float: return self._bot_object._distance_squared_unit_to_unit(self, p) ** 0.5 return self._bot_object.distance_math_hypot(self.position_tuple, p) - def distance_to_squared(self, p: Unit | Point2) -> float: + def distance_to_squared(self, p: Unit | _PointLike) -> float: """Using the 2d distance squared between self and p. Slightly faster than distance_to, so when filtering a lot of units, this function is recommended to be used. To calculate the 3d distance, use unit.position3d.distance_to(p) From dab178654f3937b08772a44f7a00089827b58ef7 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Thu, 20 Nov 2025 15:15:54 +0100 Subject: [PATCH 33/34] Fix distance_to and refactor hasattr to isinstance --- sc2/position.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sc2/position.py b/sc2/position.py index fdbefb2e..f3d9bd70 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -43,7 +43,7 @@ def distance_to(self, target: _PosLike) -> float: """Calculate a single distance from a point or unit to another point or unit :param target:""" - p: Point2 | Point3 = target if hasattr(target, "position") else target # pyright: ignore[reportAssignmentType] + p: tuple[float, ...] = target if isinstance(target, tuple) else target.position return math.hypot(self[0] - p[0], self[1] - p[1]) def distance_to_point2(self, p: _PointLike) -> float: @@ -66,7 +66,7 @@ def sort_by_distance(self, ps: Iterable[_TPosLike]) -> list[_TPosLike]: If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. :param ps:""" - return sorted(ps, key=lambda p: self.distance_to_point2(p.position if hasattr(p, "position") else p)) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportArgumentType] + return sorted(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) def closest(self, ps: Iterable[_TPosLike]) -> _TPosLike: """This function assumes the 2d distance is meant @@ -74,7 +74,7 @@ def closest(self, ps: Iterable[_TPosLike]) -> _TPosLike: :param ps:""" assert ps, "ps is empty" - return min(ps, key=lambda p: self.distance_to_point2(p if hasattr(p, "position") else p)) # pyright: ignore[reportArgumentType] + return min(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) def distance_to_closest(self, ps: Iterable[_TPosLike]) -> float: """This function assumes the 2d distance is meant @@ -82,7 +82,7 @@ def distance_to_closest(self, ps: Iterable[_TPosLike]) -> float: assert ps, "ps is empty" closest_distance = math.inf for p in ps: - p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position distance = self.distance_to_point2(p2) if distance <= closest_distance: closest_distance = distance @@ -94,7 +94,7 @@ def furthest(self, ps: Iterable[_TPosLike]) -> _TPosLike: :param ps: Units object, or iterable of Unit or Point2""" assert ps, "ps is empty" - return max(ps, key=lambda p: self.distance_to_point2(p if hasattr(p, "position") else p)) # pyright: ignore[reportArgumentType] + return max(ps, key=lambda p: self.distance_to_point2(p if isinstance(p, tuple) else p.position)) def distance_to_furthest(self, ps: Iterable[_PosLike]) -> float: """This function assumes the 2d distance is meant @@ -103,7 +103,7 @@ def distance_to_furthest(self, ps: Iterable[_PosLike]) -> float: assert ps, "ps is empty" furthest_distance = -math.inf for p in ps: - p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position distance = self.distance_to_point2(p2) if distance >= furthest_distance: furthest_distance = distance @@ -130,7 +130,7 @@ def towards(self: T, p: _PosLike, distance: float = 1, limit: bool = False) -> T :param distance: :param limit: """ - p2: Point2 | Point3 = p if hasattr(p, "position") else p # pyright: ignore[reportAssignmentType] + p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position # assert self != p, f"self is {self}, p is {p}" # TODO test and fix this if statement if self == p2: From 6658f1cc88f3ce66f196b8799c2ddc54189bdc88 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Thu, 20 Nov 2025 15:27:37 +0100 Subject: [PATCH 34/34] Add s2clientprotocol folder to docker ci --- .github/workflows/docker-ci.yml | 1 + dockerfiles/test_docker_image.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/docker-ci.yml b/.github/workflows/docker-ci.yml index f028aaa8..65270e9d 100644 --- a/.github/workflows/docker-ci.yml +++ b/.github/workflows/docker-ci.yml @@ -111,6 +111,7 @@ jobs: docker cp pyproject.toml test_container:/root/python-sc2/ docker cp uv.lock test_container:/root/python-sc2/ docker cp sc2 test_container:/root/python-sc2/sc2 + docker cp s2clientprotocol test_container:/root/python-sc2/s2clientprotocol docker cp test test_container:/root/python-sc2/test docker cp examples test_container:/root/python-sc2/examples docker exec -i test_container bash -c "pip install uv \ diff --git a/dockerfiles/test_docker_image.sh b/dockerfiles/test_docker_image.sh index 4b203c2e..7c10be5f 100644 --- a/dockerfiles/test_docker_image.sh +++ b/dockerfiles/test_docker_image.sh @@ -46,6 +46,7 @@ docker cp uv.lock test_container:/root/python-sc2/ docker exec -i test_container bash -c "pip install uv && cd python-sc2 && uv sync --no-cache --no-install-project" docker cp sc2 test_container:/root/python-sc2/sc2 +docker cp s2clientprotocol test_container:/root/python-sc2/s2clientprotocol docker cp test test_container:/root/python-sc2/test # Run various test bots