diff --git a/test_normalization.py b/test_normalization.py new file mode 100644 index 0000000..934cfa7 --- /dev/null +++ b/test_normalization.py @@ -0,0 +1,25 @@ +import torch +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent)) +from utils.normalization import ObservationNormalizer + + +def test_observation_normalizer_handles_non_contiguous_input(): + normalizer = ObservationNormalizer((2,), device="cpu") + obs = torch.randn(10, 2) + non_contig = obs[::2] + assert not non_contig.is_contiguous() + + normalizer.update(non_contig) + + # Compare with contiguous version to ensure statistics are correct + contig = non_contig.contiguous() + ref_normalizer = ObservationNormalizer((2,), device="cpu") + ref_normalizer.update(contig) + + assert torch.allclose(normalizer.obs_rms.mean, ref_normalizer.obs_rms.mean) + + normalized = normalizer.normalize(non_contig) + assert normalized.shape == non_contig.shape diff --git a/utils/normalization.py b/utils/normalization.py index 6ba5d72..2b32d6d 100644 --- a/utils/normalization.py +++ b/utils/normalization.py @@ -149,7 +149,7 @@ def __init__(self, observation_shape: Tuple[int, ...], device: str = 'cpu'): def update(self, observations: torch.Tensor): """Update normalization statistics""" # Flatten batch dimensions for update - flat_obs = observations.view(-1, *observations.shape[-len(self.obs_rms.mean.shape):]) + flat_obs = observations.reshape(-1, *observations.shape[-len(self.obs_rms.mean.shape):]) self.obs_rms.update(flat_obs) def normalize(self, observations: torch.Tensor) -> torch.Tensor: @@ -158,4 +158,4 @@ def normalize(self, observations: torch.Tensor) -> torch.Tensor: def denormalize(self, observations: torch.Tensor) -> torch.Tensor: """Denormalize observations""" - return self.obs_rms.denormalize(observations) \ No newline at end of file + return self.obs_rms.denormalize(observations)