Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Lint workflow: runs Ruff's lint rules and formatter check on pushes to
# `main` and on every pull request.
name: Lint

on:
  push:
    branches: [main]
  pull_request:

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v5
      - run: uv python install 3.10
      # Run Ruff through `uvx` instead of `uv pip install` + `uv run`:
      #  - `uv pip install` needs an existing virtual environment (or
      #    `--system`), and none is created in this job;
      #  - `uv run` inside a project syncs the project's own dependencies
      #    first, which is unnecessary for linting;
      #  - pinning to the same release as `.pre-commit-config.yaml`
      #    (v0.12.2) keeps CI and local hooks reporting the same results.
      - run: uvx ruff@0.12.2 check .
      - run: uvx ruff@0.12.2 format --check .
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pre-commit configuration: runs Ruff's linter (with autofix) and then
# Ruff's formatter, both taken from the official ruff-pre-commit hook repo.
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Pinned hook release; NOTE(review): keep in sync with the Ruff version used in CI.
rev: v0.12.2
hooks:
# Lint hook; `--fix` lets Ruff apply its automatic fixes.
- id: ruff
args: [--fix]
# Formatter hook, listed after the lint hook so fixes are applied before formatting.
- id: ruff-format
6 changes: 3 additions & 3 deletions examples/run_lightmem_ollama.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import json
import ollama
import os
import time
from tqdm import tqdm
from typing import Dict, List, Optional

from lightmem.memory.lightmem import LightMemory
import ollama
from tqdm import tqdm

from lightmem.memory.lightmem import LightMemory

# =========== Ollama Configuration ============
your_ollama_model_name = "your_Ollama_model_name_01" # such as "llama3:latest"
Expand Down
3 changes: 1 addition & 2 deletions examples/run_lightmem_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

from lightmem.memory.lightmem import LightMemory


# =========== Transformers Model Configuration ============
your_model_path_or_name = "your_model_path_or_name_01" # specify the model's path or name
your_model_device = "cuda:1" # specify the GPU device
Expand Down
51 changes: 28 additions & 23 deletions experiments/locomo/add_locomo.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from openai import OpenAI
import json
from tqdm import tqdm
import argparse
import datetime
import time
import os
import json
import logging
from lightmem.memory.lightmem import LightMemory
from lightmem.configs.retriever.embeddingretriever.qdrant import QdrantConfig
from lightmem.factory.retriever.embeddingretriever.qdrant import Qdrant
from prompts import METADATA_GENERATE_PROMPT_locomo, LoCoMo_Event_Binding_factual, LoCoMo_Event_Binding_relational
import sqlite3
import multiprocessing as mp
import os
import shutil
import argparse
import sqlite3
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import multiprocessing as mp

from prompts import (
LoCoMo_Event_Binding_factual,
LoCoMo_Event_Binding_relational,
METADATA_GENERATE_PROMPT_locomo,
)
from tqdm import tqdm

from lightmem.configs.retriever.embeddingretriever.qdrant import QdrantConfig
from lightmem.factory.retriever.embeddingretriever.qdrant import Qdrant
from lightmem.memory.lightmem import LightMemory

# ============ Configuration ============
LOGS_ROOT = "./logs"
Expand Down Expand Up @@ -357,7 +362,7 @@ def process_single_sample(sample, api_key, args):
backup_start_time = time.time()

if os.path.exists(backup_dir):
logger.info(f" Removing existing backup...")
logger.info(" Removing existing backup...")
shutil.rmtree(backup_dir)

logger.info(f" Copying: {source_dir} -> {backup_dir}")
Expand Down Expand Up @@ -389,7 +394,7 @@ def process_single_sample(sample, api_key, args):
initial_summarize_tokens = initial_summarize_stats['llm']['summarize']['total_tokens']
initial_summarize_calls = initial_summarize_stats['llm']['summarize']['calls']

logger.info(f" Creating LightMemory instance for summarization (using pre_update)")
logger.info(" Creating LightMemory instance for summarization (using pre_update)")
lightmem_for_summary = load_lightmem(
collection_name=sample_id,
api_key=api_key,
Expand Down Expand Up @@ -465,40 +470,40 @@ def process_single_sample(sample, api_key, args):
logger.info(f"SUMMARY: {sample_id}")
logger.info(f"{'='*70}")

logger.info(f"\n[Storage Information]")
logger.info("\n[Storage Information]")
logger.info(f" Pre-update: {QDRANT_PRE_UPDATE_DIR}/{sample_id} ({pre_update_count} entries)")
logger.info(f" Post-update: {QDRANT_POST_UPDATE_DIR}/{sample_id} ({post_update_count} entries)")
logger.info(f" Change: {post_update_count - pre_update_count:+d} entries")
logger.info(f" Summaries: {num_summaries}")

logger.info(f"\n[Time Statistics]")
logger.info("\n[Time Statistics]")
logger.info(f" Total: {case_total_duration:.2f}s")
logger.info(f" ├─ Add: {add_memory_duration:.2f}s ({add_memory_duration/case_total_duration*100:.1f}%)")
if args.enable_summary:
logger.info(f" ├─ Summary: {summarize_duration:.2f}s ({summarize_duration/case_total_duration*100:.1f}%)")
logger.info(f" ├─ Backup: {backup_duration:.2f}s ({backup_duration/case_total_duration*100:.1f}%)")
logger.info(f" └─ Update: {update_duration:.2f}s ({update_duration/case_total_duration*100:.1f}%)")

logger.info(f"\n[Token Statistics - Add Memory]")
logger.info("\n[Token Statistics - Add Memory]")
logger.info(f" Calls: {case_add_calls}")
logger.info(f" Prompt: {case_add_prompt:,}")
logger.info(f" Completion: {case_add_completion:,}")
logger.info(f" Total: {case_add_tokens:,}")

if args.enable_summary:
logger.info(f"\n[Token Statistics - Summarize]")
logger.info("\n[Token Statistics - Summarize]")
logger.info(f" Calls: {case_summarize_calls}")
logger.info(f" Prompt: {case_summarize_prompt:,}")
logger.info(f" Completion: {case_summarize_completion:,}")
logger.info(f" Total: {case_summarize_tokens:,}")

logger.info(f"\n[Token Statistics - Update]")
logger.info("\n[Token Statistics - Update]")
logger.info(f" Calls: {case_update_calls}")
logger.info(f" Prompt: {case_update_prompt:,}")
logger.info(f" Completion: {case_update_completion:,}")
logger.info(f" Total: {case_update_tokens:,}")

logger.info(f"\n[Total Usage]")
logger.info("\n[Total Usage]")
logger.info(f" API Calls: {case_add_calls + case_summarize_calls + case_update_calls}")
logger.info(f" Tokens: {case_add_tokens + case_summarize_tokens + case_update_tokens:,}")
logger.info(f"{'='*70}\n")
Expand Down Expand Up @@ -667,7 +672,7 @@ def main():

successful = [r for r in results if r['status'] == 'success']

main_logger.info(f"\n[Overall Statistics]")
main_logger.info("\n[Overall Statistics]")
main_logger.info(f" Total samples: {len(missing)}")
main_logger.info(f" Successful: {len(successful)}")
main_logger.info(f" Failed: {len(failed_samples)}")
Expand All @@ -679,15 +684,15 @@ def main():
total_calls = sum(r['add_calls'] + r.get('summarize_calls', 0) + r['update_calls'] for r in successful)
total_summaries = sum(r.get('num_summaries', 0) for r in successful)

main_logger.info(f"\n[Performance Metrics]")
main_logger.info("\n[Performance Metrics]")
main_logger.info(f" Avg per sample: {avg_duration:.2f}s")
main_logger.info(f" Speedup: {avg_duration * len(successful) / total_duration:.2f}x")
main_logger.info(f" Total API calls: {total_calls}")
main_logger.info(f" Total tokens: {total_tokens:,}")
main_logger.info(f" Total summaries: {total_summaries}")

if failed_samples:
main_logger.info(f"\n[Failed Samples]")
main_logger.info("\n[Failed Samples]")
for sample_id in failed_samples:
main_logger.info(f" - {sample_id}")

Expand Down
5 changes: 3 additions & 2 deletions experiments/locomo/llm_judge.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
import json
import re
from collections import defaultdict
import re

import numpy as np
from openai import OpenAI


def extract_json(text):
"""
Expand Down
12 changes: 5 additions & 7 deletions experiments/locomo/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from openai import OpenAI
import logging
import os
import time
import sqlite3
import pickle
from typing import List, Dict, Any, Set, Optional, Tuple
from collections import defaultdict, deque
import sqlite3
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import spacy
from lightmem.factory.retriever.embeddingretriever.qdrant import Qdrant

from lightmem.configs.retriever.embeddingretriever.qdrant import QdrantConfig
from lightmem.factory.retriever.embeddingretriever.qdrant import Qdrant

SPACY_AVAILABLE = True
logger = logging.getLogger(__name__)
Expand Down
27 changes: 13 additions & 14 deletions experiments/locomo/search_locomo.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from openai import OpenAI
import json
from tqdm import tqdm
import argparse
import datetime
import time
import os
import json
import logging
from typing import List, Dict, Any, Optional
import numpy as np
import argparse

from lightmem.factory.text_embedder.huggingface import TextEmbedderHuggingface
from lightmem.factory.text_embedder.openai import TextEmbedderOpenAI
from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig
import os
import time
from typing import Dict, List, Optional

import numpy as np
from llm_judge import evaluate_llm_judge
from openai import OpenAI
from prompts import ANSWER_PROMPT, ANSWER_PROMPT_StructMem
from retrievers import QdrantEntryLoader, VectorRetriever, format_related_memories
from llm_judge import evaluate_llm_judge
from tqdm import tqdm

from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig
from lightmem.factory.text_embedder.huggingface import TextEmbedderHuggingface
from lightmem.factory.text_embedder.openai import TextEmbedderOpenAI

# ============ Configuration ============
LOGS_ROOT = "./logs"
Expand Down Expand Up @@ -589,7 +588,7 @@ def main():
logger.info("=" * 80)
logger.info("LightMem Evaluation - LoCoMo Dataset")
logger.info("=" * 80)
logger.info(f"Configuration:")
logger.info("Configuration:")
logger.info(f" Dataset: {args.dataset}")
logger.info(f" Qdrant dir: {args.qdrant_dir}")
logger.info(f" Output dir: {args.output_dir}")
Expand Down
2 changes: 2 additions & 0 deletions experiments/longmemeval/offline_update.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

from lightmem.memory.lightmem import LightMemory


def load_lightmem(collection_name):
config = {
"memory_manager": {
Expand Down
8 changes: 5 additions & 3 deletions experiments/longmemeval/run_lightmem_gpt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from openai import OpenAI
import json
from tqdm import tqdm
import datetime
import time

from openai import OpenAI
from tqdm import tqdm

from lightmem.memory.lightmem import LightMemory


def get_anscheck_prompt(task, question, answer, response, abstention=False):
if not abstention:
if task in ['single-session-user', 'single-session-assistant', 'multi-session']:
Expand Down
9 changes: 5 additions & 4 deletions experiments/longmemeval/run_lightmem_qwen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from openai import OpenAI
import json
from tqdm import tqdm
import datetime
import time
import os
import time

from openai import OpenAI
from tqdm import tqdm

from lightmem.memory.lightmem import LightMemory

# ============ API Configuration ============
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,14 @@ where = ["src"]

[tool.setuptools.package-data]
"*" = ["*.yaml", "*.yml", "*.json"]

# Ruff settings shared by the linter and the formatter.
[tool.ruff]
# E501 (line-too-long) is ignored below, so the linter does not enforce this limit.
line-length = 120
# Skip Jupyter notebooks in addition to Ruff's default excludes.
extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
# E = pycodestyle errors, F = Pyflakes, I = isort-style import sorting.
select = ["E", "F", "I"]
# NOTE(review): this list silences several correctness-relevant checks, e.g.
# F821 (undefined name), E722 (bare except), F841 (unused local variable),
# F403/F405 (star imports) — presumably to let the existing code base pass;
# consider re-enabling them incrementally.
ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405", "F601", "F821", "E712"]

[tool.ruff.lint.per-file-ignores]
# F401 = unused import; presumably allowed here because package __init__ files re-export names.
"__init__.py" = ["F401"]
15 changes: 8 additions & 7 deletions src/lightmem/configs/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os
from pydantic import BaseModel, Field, model_validator
from typing import Any, Dict, Optional, Literal
from pydantic import ValidationError
from lightmem.configs.pre_compressor.base import PreCompressorConfig
from lightmem.configs.topic_segmenter.base import TopicSegmenterConfig
from typing import Literal, Optional

from pydantic import BaseModel, Field

from lightmem.configs.logging.base import LoggingConfig
from lightmem.configs.memory_manager.base import MemoryManagerConfig
from lightmem.configs.text_embedder.base import TextEmbedderConfig
from lightmem.configs.multimodal_embedder.base import MMEmbedderConfig
from lightmem.configs.pre_compressor.base import PreCompressorConfig
from lightmem.configs.retriever.contextretriever.base import ContextRetrieverConfig
from lightmem.configs.retriever.embeddingretriever.base import EmbeddingRetrieverConfig
from lightmem.configs.logging.base import LoggingConfig
from lightmem.configs.text_embedder.base import TextEmbedderConfig
from lightmem.configs.topic_segmenter.base import TopicSegmenterConfig

lightmem_dir = ""

Expand Down
6 changes: 4 additions & 2 deletions src/lightmem/configs/logging/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional, Union, Dict, List, Literal
import logging
import os
from datetime import datetime
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator


class LoggingConfig(BaseModel):
"""Logging configuration for LightMem."""
Expand Down
2 changes: 1 addition & 1 deletion src/lightmem/configs/logging/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Union, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, List, Optional, Union

if TYPE_CHECKING:
from .base import LoggingConfig
Expand Down
4 changes: 3 additions & 1 deletion src/lightmem/configs/memory_manager/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import ClassVar, List, Optional

from pydantic import BaseModel, Field, model_validator
from typing import Dict, Optional, Type, Any, List, ClassVar

from .base_config import BaseMemoryManagerConfig


Expand Down
2 changes: 1 addition & 1 deletion src/lightmem/configs/memory_manager/base_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union, List
from typing import Dict, List, Optional, Union


class BaseMemoryManagerConfig:
Expand Down
4 changes: 3 additions & 1 deletion src/lightmem/configs/multimodal_embedder/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional

from pydantic import BaseModel, Field, model_validator
from typing import Dict, Optional, Type, Any, List


class MMEmbedderConfig(BaseModel):
model_name: str = Field(description="The multimodal embedding model or Deployment platform (e.g., 'openai', 'ollama')", default="huggingface")
Expand Down
6 changes: 4 additions & 2 deletions src/lightmem/configs/pre_compressor/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pydantic import BaseModel, Field, model_validator
from typing import Dict, Optional, Type, Any, Union, ClassVar
from importlib import import_module
from typing import Any, ClassVar, Dict

from pydantic import BaseModel, Field, model_validator


class PreCompressorConfig(BaseModel):
model_name: str = Field(
Expand Down
Loading