diff --git a/engram_demo_v1.py b/engram_demo_v1.py index f3ce993..43d9f16 100644 --- a/engram_demo_v1.py +++ b/engram_demo_v1.py @@ -408,7 +408,7 @@ def forward(self,input_ids,hidden_states): for idx, layer in enumerate(LLM): if idx == 0: - hidden_states = LLM[0](input_ids) + hidden_states = layer(input_ids) ## mock hyper-connection hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, backbone_config.hc_mult, -1) elif idx == len(LLM)-1: @@ -420,4 +420,4 @@ def forward(self,input_ids,hidden_states): print("✅ Forward Complete!") print(f"{input_ids.shape=}\n{output.shape=}") - \ No newline at end of file +