diff --git a/src/Network/SerializationPatches.cs b/src/Network/SerializationPatches.cs index e6804f7..d324fe1 100644 --- a/src/Network/SerializationPatches.cs +++ b/src/Network/SerializationPatches.cs @@ -1,9 +1,10 @@ +using System; using System.Collections.Generic; using System.Linq; using System.Reflection; using HarmonyLib; using MegaCrit.Sts2.Core.Entities.Multiplayer; -using MegaCrit.Sts2.Core.Multiplayer.Messages.Lobby; +using MegaCrit.Sts2.Core.Logging; using MegaCrit.Sts2.Core.Multiplayer.Serialization; namespace RemoveMultiplayerPlayerLimit.Network; @@ -13,10 +14,9 @@ namespace RemoveMultiplayerPlayerLimit.Network; // ║ ║ // ║ 修改官方协议消息的 SlotId / LobbyList 序列化位宽: ║ // ║ • LobbyPlayer.slotId : 2 → 4 bits ║ -// ║ • ClientLobbyJoinResponse.list : 3 → 5 bits ║ -// ║ • LobbyBeginRunMessage.list : 3 → 5 bits ║ +// ║ • Lobby message player lists : 3 → 5 bits ║ // ║ ║ -// ║ 每个消息的 Serialize / Deserialize 各一个补丁,成对保证位宽一致。 ║ +// ║ Lobby 列表补丁扫描官方 Lobby 消息,避免漏掉 8 人时才触发的消息。 ║ // ╚══════════════════════════════════════════════════════════════════════════╝ /// @@ -68,46 +68,89 @@ private static IEnumerable Transpiler(IEnumerable Transpiler(IEnumerable instructions) - => TranspilerUtils.ReplaceBitWidthBeforeCall(instructions, + private static IEnumerable TargetMethods() + => LobbyMessagePatchTargets.GetPacketMethods(nameof(IPacketSerializable.Serialize), typeof(PacketWriter)); + + private static IEnumerable Transpiler(IEnumerable instructions, MethodBase __originalMethod) + => LobbyMessagePatchTargets.ReplaceLobbyListBits(instructions, SerializationMethods.WriteListWithBits, - ProtocolConfig.VanillaLobbyListLengthBits, ProtocolConfig.LobbyListLengthBits, - nameof(ClientLobbyJoinResponseSerializePatch)); + nameof(LobbyMessageSerializeListPatch), + __originalMethod); } -[HarmonyPatch(typeof(ClientLobbyJoinResponseMessage), nameof(ClientLobbyJoinResponseMessage.Deserialize))] -internal static class ClientLobbyJoinResponseDeserializePatch +[HarmonyPatch] +internal static class LobbyMessageDeserializeListPatch { - private static IEnumerable Transpiler(IEnumerable instructions) - => TranspilerUtils.ReplaceBitWidthBeforeCall(instructions, + private static IEnumerable TargetMethods() + => LobbyMessagePatchTargets.GetPacketMethods(nameof(IPacketSerializable.Deserialize), typeof(PacketReader)); + + private static IEnumerable Transpiler(IEnumerable instructions, MethodBase __originalMethod) + => LobbyMessagePatchTargets.ReplaceLobbyListBits(instructions, SerializationMethods.ReadListWithBits, - ProtocolConfig.VanillaLobbyListLengthBits, ProtocolConfig.LobbyListLengthBits, - nameof(ClientLobbyJoinResponseDeserializePatch)); + nameof(LobbyMessageDeserializeListPatch), + __originalMethod); } -// ── LobbyBeginRunMessage ─────────────────────────────────────────────── - -[HarmonyPatch(typeof(LobbyBeginRunMessage), nameof(LobbyBeginRunMessage.Serialize))] -internal static class LobbyBeginRunSerializePatch +internal static class LobbyMessagePatchTargets { - private static IEnumerable Transpiler(IEnumerable instructions) - => TranspilerUtils.ReplaceBitWidthBeforeCall(instructions, - SerializationMethods.WriteListWithBits, - ProtocolConfig.VanillaLobbyListLengthBits, ProtocolConfig.LobbyListLengthBits, - nameof(LobbyBeginRunSerializePatch)); -} + private const string LobbyMessagesNamespace = "MegaCrit.Sts2.Core.Multiplayer.Messages.Lobby"; -[HarmonyPatch(typeof(LobbyBeginRunMessage), nameof(LobbyBeginRunMessage.Deserialize))] -internal static class LobbyBeginRunDeserializePatch -{ - private static IEnumerable Transpiler(IEnumerable instructions) - => TranspilerUtils.ReplaceBitWidthBeforeCall(instructions, - SerializationMethods.ReadListWithBits, - ProtocolConfig.VanillaLobbyListLengthBits, ProtocolConfig.LobbyListLengthBits, - nameof(LobbyBeginRunDeserializePatch)); + internal static IEnumerable GetPacketMethods(string methodName, Type packetType) + { + foreach (Type type in GetSts2Types()) + { + if (type.Namespace != LobbyMessagesNamespace || type.IsAbstract) + { + continue; + } + MethodInfo? method = type.GetMethod(methodName, + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, + null, + new[] { packetType }, + null); + if (method == null || method.IsAbstract || method.DeclaringType != type) + { + continue; + } + yield return method; + } + } + + internal static IEnumerable ReplaceLobbyListBits( + IEnumerable instructions, + MethodInfo? targetMethod, + string patchName, + MethodBase originalMethod) + { + IEnumerable result = TranspilerUtils.ReplaceBitWidthBeforeCall(instructions, + targetMethod, + ProtocolConfig.VanillaLobbyListLengthBits, + ProtocolConfig.LobbyListLengthBits, + patchName, + requireReplacement: false, + out int replacementCount); + + if (replacementCount > 0) + { + Log.Info($"{patchName}: widened {replacementCount} lobby list bit-width operand(s) in {originalMethod.DeclaringType?.Name}.{originalMethod.Name}."); + } + return result; + } + + private static IEnumerable GetSts2Types() + { + try + { + return typeof(LobbyPlayer).Assembly.GetTypes(); + } + catch (ReflectionTypeLoadException ex) + { + return ex.Types.Where(type => type != null).Cast(); + } + } } diff --git a/src/Network/TranspilerUtils.cs b/src/Network/TranspilerUtils.cs index 598581a..8201407 100644 --- a/src/Network/TranspilerUtils.cs +++ b/src/Network/TranspilerUtils.cs @@ -37,6 +37,22 @@ internal static IEnumerable ReplaceBitWidthBeforeCall( int sourceBitWidth, int targetBitWidth, string patchName) + => ReplaceBitWidthBeforeCall(instructions, + targetMethod, + sourceBitWidth, + targetBitWidth, + patchName, + requireReplacement: true, + out _); + + internal static IEnumerable ReplaceBitWidthBeforeCall( + IEnumerable instructions, + MethodInfo? targetMethod, + int sourceBitWidth, + int targetBitWidth, + string patchName, + bool requireReplacement, + out int replacementCount) { MethodInfo resolvedTargetMethod = targetMethod ?? throw new InvalidOperationException($"{patchName}: target method is null."); @@ -59,7 +75,9 @@ internal static IEnumerable ReplaceBitWidthBeforeCall( count++; } - if (count == 0) + replacementCount = count; + + if (requireReplacement && count == 0) { throw new InvalidOperationException( $"{patchName}: no bit-width operand replaced for method " +