class BaseDownloader:
    """Common state and helpers shared by the database-driven downloaders.

    Args:
        database_path: Path of the SQLite database that stores the links.
        table_name: Table queried for ``internal_id`` / ``download_links``.
        save_root: Directory under which one sub-directory per record is made.
    """

    # Browser-like headers sent with every HTTP request.
    DEFAULT_HEADERS = {
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36",
    }

    def __init__(self, database_path, table_name, save_root):
        self.database_path = database_path
        self.table_name = table_name
        self.save_root = save_root

    def create_connection(self):
        """Open and return a new SQLite connection to ``database_path``."""
        return sqlite3.connect(self.database_path)

    async def check_file_exists(self, file_path):
        """Return True when *file_path* already exists on disk."""
        return os.path.exists(file_path)

    def get_response_headers(self, url, timeout=30):
        """Return the headers of a non-redirecting GET on *url*.

        This is a blocking call (``requests``); async callers should run it in
        an executor. It is deliberately best-effort: any failure is printed
        and an empty mapping is returned so the caller can fall back to *url*.

        Args:
            url: Address to probe.
            timeout: Seconds to wait before giving up, so a stalled server
                cannot hang the download forever.
        """
        try:
            response = requests.get(
                url,
                headers=dict(self.DEFAULT_HEADERS),
                allow_redirects=False,
                timeout=timeout,
            )
            return response.headers
        except Exception as e:
            print(f"An error occurred: {e}")
            return {}

    async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None):
        """Download one file; must be implemented by subclasses."""
        raise NotImplementedError("This method should be overridden by subclasses")


class Downloader(BaseDownloader):
    """Resumable, concurrent downloader for HCA, SCP and CellxGene links.

    Args:
        database_path: Path of the SQLite database that stores the links.
        table_name: Table queried for ``internal_id`` / ``download_links``.
        save_root: Directory under which one sub-directory per record is made.
        num_workers: Maximum number of simultaneous downloads (>= 1).
        downloader_type: One of ``'hca'``, ``'scp'`` or ``'cxg'``.
        timeout: Total seconds allowed per request, or an
            ``aiohttp.ClientTimeout`` instance.
        **kwargs: ``dcp`` (HCA catalog that replaces ``dcp44`` in the links)
            and ``cookie`` (dict, or path of a JSON file, sent with SCP
            requests). Both are optional.

    Raises:
        ValueError: For an unknown ``downloader_type``, ``num_workers < 1`` or
            an unusable ``cookie`` value.
    """

    SUPPORTED_TYPES = ('hca', 'scp', 'cxg')
    CHUNK_SIZE = 1024 * 1024  # bytes read from the response per iteration

    def __init__(self, database_path, table_name, save_root, num_workers=1, downloader_type='hca', timeout=7200, **kwargs):
        super().__init__(database_path, table_name, save_root)
        if downloader_type not in self.SUPPORTED_TYPES:
            raise ValueError(f"unsupported database type: {downloader_type}")
        if num_workers < 1:
            # A semaphore initialised with 0 would block every download forever.
            raise ValueError(f"num_workers must be at least 1, got {num_workers}")
        self.num_workers = num_workers
        self.downloader_type = downloader_type
        self.timeout = timeout
        self.dcp = kwargs.get('dcp') if downloader_type == 'hca' else None
        self.cookie = self._process_cookie(kwargs.get('cookie')) if downloader_type == 'scp' else None

    @staticmethod
    def _process_cookie(cookie):
        """Normalise the ``cookie`` argument to a dict (or None).

        Args:
            cookie: ``None`` (no cookie), a dict, or the path of a JSON file
                that contains a dict.

        Returns:
            The cookie dict, or ``None`` when no cookie was supplied.

        Raises:
            ValueError: If the value is neither of the above or the JSON file
                cannot be read as a dict.
        """
        if cookie is None:
            # The cookie is optional: download_file only sends it when set.
            return None
        if isinstance(cookie, dict):
            return cookie
        if isinstance(cookie, str) and os.path.isfile(cookie):
            try:
                with open(cookie, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as e:
                raise ValueError(f"Failed to load cookie from JSON file at {cookie}: {e}") from e
            if not isinstance(loaded, dict):
                raise ValueError(f"Failed to load cookie from JSON file at {cookie}: expected a JSON object, got {type(loaded)}")
            return loaded
        raise ValueError(f"Invalid cookie format. Expected a dictionary or a valid JSON file path, got {type(cookie)}")

    @staticmethod
    def _filename_for(downloader_type, final_url):
        """Derive the local file name from the last path segment of *final_url*."""
        last_segment = final_url.split('/')[-1]
        if downloader_type == 'hca':
            # HCA: the name is the path component, the query is a signature.
            return last_segment.split('?')[0]
        if downloader_type == 'scp':
            # SCP: the name is carried in the ``filename=`` query parameter.
            return last_segment.split('?')[-1].replace('filename=', '')
        return last_segment.split('?')[-1]

    @staticmethod
    def _parse_size(value):
        """Return *value* as a non-negative int, or None when it is not one."""
        if value is None:
            return None
        try:
            size = int(str(value).strip())
        except ValueError:
            return None
        return size if size >= 0 else None

    @classmethod
    def _total_from_content_range(cls, content_range):
        """Extract the full size from ``bytes a-b/total``; None for ``*``."""
        if not content_range:
            return None
        return cls._parse_size(content_range.partition('/')[-1])

    @staticmethod
    def _iter_links(link_data):
        """Yield the downloadable links found in a dict's values or a list."""
        candidates = link_data.values() if isinstance(link_data, dict) else link_data
        for link in candidates:
            if isinstance(link, str) and link.startswith(('https://', 'ftp://')):
                yield link

    def _client_timeout(self):
        """Return ``self.timeout`` as an ``aiohttp.ClientTimeout``."""
        if isinstance(self.timeout, aiohttp.ClientTimeout):
            return self.timeout
        return aiohttp.ClientTimeout(total=self.timeout)

    async def _resolve_target(self, url, save_dir):
        """Return ``(final_url, file_path)`` for *url* according to the type."""
        if self.downloader_type == 'hca':
            if self.dcp is not None:
                url = url.replace('dcp44', self.dcp)
            # requests is blocking; run it in the default executor so the
            # other downloads keep making progress meanwhile.
            loop = asyncio.get_running_loop()
            response_headers = await loop.run_in_executor(None, self.get_response_headers, url)
            final_url = response_headers.get('Location', url)
        elif self.downloader_type == 'scp':
            final_url = url.replace('/api/v1/site/studies', '/data/public')
        else:
            final_url = url
        file_name = self._filename_for(self.downloader_type, final_url)
        return final_url, os.path.join(save_dir, file_name)

    async def _download(self, session, url, save_dir):
        """Fetch *url* into *save_dir*, resuming a partial local copy."""
        final_url, file_path = await self._resolve_target(url, save_dir)

        headers = dict(self.DEFAULT_HEADERS)
        wrote = 0
        if await self.check_file_exists(file_path):
            wrote = os.path.getsize(file_path)
            headers['Range'] = f'bytes={wrote}-'
            print(f"Resuming download of {file_path} from byte {wrote}")

        request_kwargs = {'headers': headers, 'timeout': self._client_timeout()}
        if self.downloader_type == 'scp' and self.cookie:
            request_kwargs['cookies'] = self.cookie

        async with session.get(final_url, **request_kwargs) as response:
            if response.status == 416:
                # Requested range starts at/after the end: nothing left to get.
                print(f"The local copy of {file_path} is complete.")
                return

            # Never write an HTTP error page into the target file.
            response.raise_for_status()

            if response.status == 206:
                mode = 'ab'
                total_size = self._total_from_content_range(response.headers.get('Content-Range'))
            else:
                # The whole body is being sent. If a partial copy exists the
                # server ignored the Range header, so start the file over
                # instead of appending a full copy to the partial one.
                if wrote:
                    print(f"Server ignored the range request; restarting download of {file_path}")
                    wrote = 0
                mode = 'wb'
                total_size = self._parse_size(response.headers.get('Content-Length'))

            if total_size is None:
                print("Warning: Unable to determine total size. Downloading without progress tracking.")
                pbar = tqdm(unit='B', unit_scale=True, desc=f"Downloading {file_path}", leave=False)
            else:
                pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {file_path}",
                            initial=wrote, leave=False)

            try:
                async with aiofiles.open(file_path, mode) as f:
                    async for chunk in response.content.iter_chunked(self.CHUNK_SIZE):
                        await f.write(chunk)
                        wrote += len(chunk)
                        pbar.update(len(chunk))
            finally:
                pbar.close()

            if total_size is not None and wrote != total_size:
                print(f"ERROR: Incomplete download detected for {file_path}. Expected {total_size} bytes, got {wrote}.")
            else:
                print(f"Successfully downloaded {file_path}")

    async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None):
        """Download *url* into *save_dir* while holding *semaphore*.

        Errors are printed rather than raised so one failing link does not
        cancel the others. ``overall_progress`` is advanced exactly once per
        call, whatever the outcome. ``progress`` is accepted for interface
        compatibility and is not used.
        """
        async with semaphore:
            try:
                await self._download(session, url, save_dir)
            except Exception as e:
                print(f"Error occurred while downloading {url}: {e}")
            finally:
                # Compare with None: a tqdm bar's truth value depends on its total.
                if overall_progress is not None:
                    overall_progress.update(1)

    async def main(self):
        """Read the link table and download every listed file concurrently."""
        conn = self.create_connection()
        try:
            cursor = conn.cursor()
            # NOTE(review): table_name is interpolated because SQL identifiers
            # cannot be bound as parameters; it must come from a trusted source.
            cursor.execute(f"SELECT internal_id as id, download_links as link FROM {self.table_name} WHERE download_links IS NOT NULL")
            rows = cursor.fetchall()
        finally:
            conn.close()

        # Collect the jobs first so the progress total matches the number of
        # downloads that are actually started.
        jobs = []
        for record_id, link_json in rows:
            try:
                link_data = json.loads(link_json)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON for ID {record_id}: {e}")
                continue

            save_dir = os.path.join(self.save_root, str(record_id))
            os.makedirs(save_dir, exist_ok=True)

            if isinstance(link_data, (dict, list)):
                jobs.extend((link, save_dir) for link in self._iter_links(link_data))
            else:
                print(f"Unsupported data type for ID {record_id}: Expected dict or list, got {type(link_data)}")

        semaphore = asyncio.Semaphore(self.num_workers)
        with tqdm(total=len(jobs), desc="Overall Progress") as overall_progress:
            async with aiohttp.ClientSession() as session:
                await asyncio.gather(*(
                    self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress)
                    for link, save_dir in jobs
                ))
__name__ == "__main__": + # HCA下载器示例 + hca_downloader = Downloader( + database_path="path/to/your/database.db", + table_name="your_table_name", + save_root="path/to/save/directory", + downloader_type="hca", + num_workers=4, # 并发数 + dcp="your_dcp_value", # HCA特定参数 + timeout=7200 # 可选的超时设置 + ) + + # SCP下载器示例 + scp_downloader = Downloader( + database_path="path/to/your/database.db", + table_name="your_table_name", + save_root="path/to/save/directory", + downloader_type="scp", + num_workers=4, # 并发数 + cookie={"your": "cookie"}, # SCP特定参数 + timeout=7200 # 可选的超时设置 + ) + + # 运行下载器 + asyncio.run(hca_downloader.main()) # 或者 + asyncio.run(scp_downloader.main()) \ No newline at end of file diff --git a/BSM/Fetcher/SingleCellDBs/cellxgene.py b/BSM/Fetcher/SingleCellDBs/cellxgene.py index e3622e3..c4310b2 100644 --- a/BSM/Fetcher/SingleCellDBs/cellxgene.py +++ b/BSM/Fetcher/SingleCellDBs/cellxgene.py @@ -13,12 +13,14 @@ def __init__(self, domain_name="cellxgene.cziscience.com/curation/v1"): self.headers = {"Content-Type": "application/json"} def fetch_dataset(self): + self.logger.info('fetching all cellxgene datasets') res = requests.get(url=self.datasets_url, headers=self.headers) res.raise_for_status() data = res.json() return data def fetch_collections(self): + self.logger.info('fetching all cellxgene collections') res = requests.get(url=self.collections_url, headers=self.headers) res.raise_for_status() data = res.json() @@ -27,7 +29,6 @@ def fetch_collections(self): def fetch(self, db_name): collections = self.fetch_collections() datasets = self.fetch_dataset() - merged_datasets = [] for collection in collections: collection_datasets = collection.get('datasets', []) @@ -40,5 +41,5 @@ def fetch(self, db_name): json_manager = JsonManager(db_name) json_manager.save(merged_datasets) - self.logger.info("Data saved successfully to JSON file.") + self.logger.info(f"Data saved successfully to {db_name} file.") diff --git a/BSM/Fetcher/SingleCellDBs/exploredata.py 
b/BSM/Fetcher/SingleCellDBs/exploredata.py index bd6061f..95867c5 100644 --- a/BSM/Fetcher/SingleCellDBs/exploredata.py +++ b/BSM/Fetcher/SingleCellDBs/exploredata.py @@ -10,14 +10,13 @@ class ExploreDataFetcher(SingleCellDBFetcher): - def __init__(self, project_url=r'https://service.azul.data.humancellatlas.org/index/projects?size=100&catalog=dcp44&order=asc&sort=projectTitle&filters=%7B%7D', - files_url=r'https://service.azul.data.humancellatlas.org/index/files'): + def __init__(self, project_url=None, files_url=None, dcp_num='dcp44'): super().__init__() self.project_meta_data = [] self.project_meta_data_with_url = [] - self.project_url = project_url - self.files_url = files_url - self.dcp_num = "dcp44" + self.dcp_num = dcp_num + self.project_url = rf'https://service.azul.data.humancellatlas.org/index/projects?size=100&catalog={self.dcp_num}&order=asc&sort=projectTitle&filters=%7B%7D' if project_url is None else project_url + self.files_url = r'https://service.azul.data.humancellatlas.org/index/files' if files_url is None else files_url self.headers = {'Accept': 'application/json, text/plain, */*'} def fetch(self, file_name): @@ -48,14 +47,14 @@ def fetch_project(self): response = requests.get(url) response.raise_for_status() - # 解析JSON数据 + data = response.json() total = data['pagination']['total'] with tqdm(total=total, desc='Fetching Project Data', initial=data['pagination']['count']) as pbar: while url: hits = data.get('hits', []) self.project_meta_data.extend(hits) - # 获取下一页的URL + pagination = data.get('pagination', {}) url = pagination.get('next', None) if url: @@ -86,9 +85,7 @@ def fetch_url(self, projects): with tqdm(total=total, desc='Fetching Project URLs', initial=data['pagination']['count']) as pbar: while url: - # 提取“hits”字段 hits = data.get('hits', []) - # 处理每个hit的files字段 for hit in hits: files = hit.get('files', []) for file in files: @@ -98,7 +95,6 @@ def fetch_url(self, projects): file.pop('format') aggregated_data[file_format].append(file) - # 
获取下一页的URL pagination = data.get('pagination', {}) url = pagination.get('next', None) if url: diff --git a/BSM/Fetcher/SingleCellDBs/fetchers.py b/BSM/Fetcher/SingleCellDBs/fetchers.py index 1b36584..07ec43f 100644 --- a/BSM/Fetcher/SingleCellDBs/fetchers.py +++ b/BSM/Fetcher/SingleCellDBs/fetchers.py @@ -4,7 +4,6 @@ class SingleCellDBFetcher(object): def __init__(self): - # 设置日志记录器 self.logger = logging.getLogger(__name__) def fetch(self, db_name): diff --git a/BSM/Fetcher/SingleCellDBs/single_cell_portal.py b/BSM/Fetcher/SingleCellDBs/single_cell_portal.py index c87a56c..49f2ec6 100644 --- a/BSM/Fetcher/SingleCellDBs/single_cell_portal.py +++ b/BSM/Fetcher/SingleCellDBs/single_cell_portal.py @@ -1,4 +1,6 @@ import requests +from tqdm import tqdm + from BSM.Fetcher.SingleCellDBs.fetchers import SingleCellDBFetcher from BSM.Fetcher.utils import JsonManager @@ -15,18 +17,20 @@ def fetch(self, db_name): if response.status_code == 200: studies = response.json() final_data = [] - for study in studies: + for study in tqdm(studies): accessions = study.get('accession', 'N/A') study_url = f"{self.datasets_url}/{accessions}" response = requests.get(study_url, headers=self.headers, verify=False) if response.status_code == 200: study_data = response.json() final_data.append(study_data) - self.logger.info(f"Data saved successfully to {accessions}.json file.") + # self.logger.info() + tqdm.write(f"Data saved successfully to {accessions}.json file.") else: self.logger.error(f"Failed to retrieve study {accessions}. Status code: {response.status_code}") manager = JsonManager(db_name) manager.save(final_data) + self.logger.info(f"Data saved successfully to {db_name} file.") else: self.logger.error(f"Failed to retrieve studies. 
class ZJlabOSSConnector:
    """Small helper around ``oss2`` for uploading files and syncing folders.

    Args:
        access_key: OSS access key id.
        secret: OSS access key secret.
        endpoint: OSS endpoint the buckets are reached through.
    """

    def __init__(self, access_key, secret, endpoint):
        self.auth = oss2.Auth(access_key, secret)
        self.endpoint = endpoint

    def upload_file(self, bucket_name, remote_fp, local_fp, progress_callback=None):
        """Upload *local_fp* to *remote_fp* in *bucket_name*.

        Args:
            bucket_name: Name of the target bucket.
            remote_fp: Object key to write.
            local_fp: Path of the local file to read.
            progress_callback: Optional ``(consumed_bytes, total_bytes)``
                callable forwarded to ``put_object``.

        Returns:
            The value returned by ``bucket.put_object``.
        """
        bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name)
        with open(local_fp, 'rb') as fileobj:
            return bucket.put_object(remote_fp, fileobj, progress_callback=progress_callback)

    @staticmethod
    def _remote_key(remote_dir, relative_path):
        """Join *remote_dir* and *relative_path* into an object key.

        Leading/trailing slashes of *remote_dir* are dropped so the key never
        starts with ``/`` and never contains ``//`` at the join, including
        when *remote_dir* is empty (bucket root).
        """
        prefix = (remote_dir or '').strip('/')
        return f'{prefix}/{relative_path}' if prefix else relative_path

    def sync_folder(self, bucket_name, local_dir, remote_dir):
        """Upload every file under *local_dir* that is missing from the bucket.

        Existence is checked with ``object_exists``; existing objects are
        never overwritten or compared by content.

        :param bucket_name: The name of the bucket on OSS
        :param local_dir: Path to the local directory
        :param remote_dir: Path to the remote directory (excluding bucket name)
        :return: dict with the number of ``uploaded``, ``skipped`` and
            ``failed`` files
        """
        bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name)
        summary = {'uploaded': 0, 'skipped': 0, 'failed': 0}

        for root, _dirs, files in os.walk(local_dir):
            for file_name in files:
                local_path = os.path.join(root, file_name)
                # Convert only the platform separator, so a literal backslash
                # in a POSIX file name is preserved.
                relative_path = os.path.relpath(local_path, local_dir).replace(os.sep, '/')
                remote_path = self._remote_key(remote_dir, relative_path)

                try:
                    exist = bucket.object_exists(remote_path)
                except oss2.exceptions.NoSuchKey:
                    exist = False
                except oss2.exceptions.NoSuchBucket as e:
                    # No file can succeed without the bucket: stop instead of
                    # logging the same error once per file.
                    logging.error(f'Bucket {bucket_name} does not exist or is inaccessible: {e}')
                    return summary
                except Exception as e:
                    # Best effort: treat the object as missing and try the
                    # upload, but make the failed check visible.
                    logging.warning(f'Error checking existence of {relative_path}: {e}')
                    exist = False

                if exist:
                    logging.info(f'File {relative_path} already exists. Skipping...')
                    summary['skipped'] += 1
                    continue

                logging.info(f'New file {relative_path}. Uploading...')
                try:
                    self.upload_file(bucket_name, remote_path, local_path, progress_callback=self.percentage)
                    logging.info(f'Successfully uploaded {relative_path}')
                    summary['uploaded'] += 1
                except Exception as e:
                    logging.error(f'Failed to upload {relative_path}: {e}')
                    summary['failed'] += 1

        return summary

    @staticmethod
    def percentage(consumed_bytes, total_bytes):
        """Progress callback: print the upload percentage on one line."""
        if total_bytes:
            rate = int(100 * (float(consumed_bytes) / float(total_bytes)))
            print('\r{0}% '.format(rate), end='')
            sys.stdout.flush()
class OpenAI_Chat(VannaBase):
    """Vanna chat backend backed by an OpenAI-compatible chat-completions API.

    Args:
        client: Ready-made OpenAI client. When given it is used as is and the
            ``api_key`` / ``base_url`` config entries are ignored.
        config: Optional dict. Recognised keys: ``temperature``, ``api_key``,
            ``base_url``, ``model`` and ``engine``.

    Raises:
        Exception: If the config contains one of the deprecated keys
            ``api_type``, ``api_base`` or ``api_version``.
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # Treat a missing config as empty so the lookups below cannot fail
        # with "argument of type 'NoneType' is not iterable".
        settings = config if config is not None else {}

        # default parameters - can be overridden using config
        self.temperature = settings.get("temperature", 0.7)

        for deprecated in ("api_type", "api_base", "api_version"):
            if deprecated in settings:
                raise Exception(
                    f"Passing {deprecated} is now deprecated. Please pass an OpenAI client instead."
                )

        if client is not None:
            self.client = client
            return

        # Build the client in one step so that a config with ``base_url`` but
        # no ``api_key`` still yields a client (key taken from the environment).
        client_kwargs = {"api_key": settings.get("api_key") or os.getenv("OPENAI_API_KEY")}
        if "base_url" in settings:
            client_kwargs["base_url"] = settings["base_url"]
        self.client = OpenAI(**client_kwargs)

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``system`` role."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``user`` role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``assistant`` role."""
        return {"role": "assistant", "content": message}

    def _select_model(self, num_tokens, **kwargs):
        """Return ``(label, name)`` of the model to use.

        Precedence: ``model`` kwarg, ``engine`` kwarg, ``engine`` in the
        config, ``model`` in the config, then a default picked from the
        approximate prompt size. ``label`` is only used in the log line.
        """
        if kwargs.get("model", None) is not None:
            return "model", kwargs["model"]
        if kwargs.get("engine", None) is not None:
            return "model", kwargs["engine"]

        config = getattr(self, "config", None) or {}
        if "engine" in config:
            return "engine", config["engine"]
        if "model" in config:
            return "model", config["model"]

        return "model", "gpt-3.5-turbo-16k" if num_tokens > 3500 else "gpt-3.5-turbo"

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send *prompt* (a list of chat messages) and return the reply text.

        Args:
            prompt: Non-empty list of ``{"role": ..., "content": ...}`` dicts.
            **kwargs: Optional ``model`` or ``engine`` overriding the config.

        Raises:
            Exception: If *prompt* is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = sum(len(message["content"]) for message in prompt) / 4

        label, model = self._select_model(num_tokens, **kwargs)
        print(f"Using {label} {model} for {num_tokens} tokens (approx)")

        # An "engine" (legacy name for a deployment) is sent through the
        # ``model`` parameter: the v1 client used here (``OpenAI`` /
        # ``client.chat.completions``) has no ``engine`` argument.
        response = self.client.chat.completions.create(
            model=model,
            messages=prompt,
            stop=None,
            temperature=self.temperature,
        )

        # Prefer a choice that carries a non-empty ``text`` attribute. The
        # previous membership test (``"text" in choice``) does not check
        # attributes on a response object.
        for choice in response.choices:
            text = getattr(choice, "text", None)
            if text:
                return text

        # Otherwise return the first choice's message content (which may be empty)
        return response.choices[0].message.content
+ """ + # Map SQLite column names to standard column names + mapping = { + "cid": "ordinal_position", + "name": "column_name", + "type": "data_type", + "notnull": "is_nullable", + "dflt_value": "column_default", + "pk": "primary_key" + } + + # Rename columns based on the mapping + df = df.rename(columns=mapping) + + # Add missing columns required by the original function + df["table_catalog"] = "main" # SQLite doesn't have a catalog concept; use "main" + df["table_schema"] = "main" # SQLite doesn't have schemas; use "main" + df["table_name"] = table_name + df["is_nullable"] = df["is_nullable"].apply(lambda x: "NO" if x else "YES") + df["comment"] = None # SQLite doesn't support comments on columns + + # Reorder columns to match the expected format + standard_columns = [ + "table_catalog", "table_schema", "table_name", + "column_name", "data_type", "is_nullable", + "column_default", "ordinal_position", "comment" + ] + + return df[standard_columns] + +class BSMVannaWrapper: + def __init__(self, api_key: str, db_path: str, model='gpt-4o', base_url='https://api.openai.com/v1/'): + """ + Initialize the wrapper class for MyVanna. + + Args: + api_key (str): The API key for OpenAI. + db_path (str): The path to the SQLite database file. + """ + # Initialize configuration + self.config = { + 'api_key': api_key, + 'model': model, + 'base_url': base_url + } + + # Initialize MyVanna instance + self.vn = BSMVanna(config=self.config) + self.vn.connect_to_sqlite(db_path) + + def train(self, table_name: str): + """ + Train the model on the schema of a specific table using the training plan. + + Args: + table_name (str): The name of the table to train on. 
+ """ + # Get the table information schema + df_information_schema = self.vn.run_sql(f"PRAGMA table_info({table_name})") + df_information_schema = self.vn.convert_sqlite_df_to_standard(df_information_schema, table_name) + + # Generate and execute the training plan + plan = self.vn.get_training_plan_generic(df_information_schema) + self.vn.train(plan=plan) + + def ask(self, question: str, table='Sample') -> tuple: + """ + Ask a question and get the SQL query and result DataFrame. + + Args: + question (str): The user's question. + + Returns: + tuple: A tuple containing the SQL query (str) and the result DataFrame (pd.DataFrame). + """ + self.train(table) + sql, df, fig = self.vn.ask(question=question, visualize=False, allow_llm_to_see_data=True) + return sql, df + +# if __name__ == "__main__": +# # Initialize the wrapper class with API parameters and database path +# wrapper = BSMVannaWrapper( +# api_key='sk-jxxxxxxxxxxxxxxxxxxxxxx', +# db_path=r'path/to/your/xx.db' +# ) +# +# # Train on a specific table +# wrapper.train('Sample') +# +# # Ask a question +# sql, df = wrapper.ask( +# question="What's the internal_id corresponding to GSE204684? " +# "The column geo_ids may contain more than one ID." 
+# ) +# +# # Print results +# print("Generated SQL Query:") +# print(sql) +# print("\nResult DataFrame:") +# print(df) \ No newline at end of file diff --git a/README.md b/README.md index d567f1f..7ac3cbf 100644 --- a/README.md +++ b/README.md @@ -10,4 +10,128 @@ Fetch, process and manage metadata and data samples for following databases: - [Broad Institue - single cell portal](https://singlecell.broadinstitute.org/single_cell) - +singlecelldb + +## Installation + +``` +pip install -r requirements.txt +``` + +## Usage + +### Fetchers -- Fetch metadata + +```bash +# Fetch from Single Cell Portal +python cli.py fetch --database scp --output scp_data.json + +# Fetch from Human Cell Atlas +python cli.py fetch --database hca --output hca_data.json + +# Fetch from CellxGene +python cli.py fetch --database cxg --output cxg_data.json + +``` + +### Processors -- Alignment +```bash +python cli.py process \ + --source scp \ + --input scp_data.json \ + --output-dir output/processed \ + --database processed_data.db \ + --schema DBS/json_schema.xlsx \ + --base-url your-api-url \ + --api-key your-api-key \ + --model gpt-4o + +# Advanced usage with custom parameters +python cli.py process \ + --source hca \ + --input data/hca_metadata.json \ + --output-dir output/hca_processed \ + --database projects.db \ + --schema custom_schema.json \ + --base-url "https://custom-api.example.com/v1/" \ + --api-key "your-api-key" \ + --model "custom-model" \ + --batch-size 10 \ + --workers 8 \ + --log-file logs/processing.log +``` +### Downloaders -- Download samples +``` +python cli.py download \ + --type scp \ + --database path/to/database.db \ + --table your_table \ + --save-dir test_downloader \ + --workers 1 \ + --timeout 7200 \ + --cookie path/to/cookie.json +``` + +### Vanna -- Text to SQL + +``` +python cli.py retrieve \ + --query "What's the title corresponding to GSE204684? The column geo_ids may contain more than one ID." 
\ + --api-key "your_api_key" \ + --model "gpt-4o" \ + --db-path "path/to/your/xx.db" \ + --table "Sample" +``` + + + +## Evaluation + +#### task 1 Data entry accuracy 数据入库准确率 + +对给定的json格式存储的meta数据进行清洗存入数据库,再读取,恢复为原 json,对指定的 json key进行比较,如一致则为成功 + +| 指标及说明 | 正确率 - accuracy (入库正确的样本/成功入库样本数) | 失败率 - failure rate (入库失败,如json格式解析错误,不合法字符串引起的入库失败样本/总样本数) | +| ----------- | ------------------------------------------------- | ------------------------------------------------------------ | +| kimi | | | +| qwen | | | +| gpt-4o | | | +| deepseek v3 | | | + +数据源:不需要额外标注数据源,可直接用fetcher获得的数据测试 + +| 数据源 | 样本数 | +| ------------------------------------------------------------ | ------ | +| [Cellxgene](https://cellxgene.cziscience.com/datasets) | 100 | +| [Human Cell Atlas (data explorer)](https://explore.data.humancellatlas.org/projects) | 100 | +| [Broad Institute - single cell portal](https://singlecell.broadinstitute.org/single_cell) | 100 | + +#### task 2 Data Cleaning Quality Assessment + +对齐后的样本与人工标注样本比较 + +| 指标及说明 | Field-level accuracy (入库正确的item/总item) | **Overall Sample Accuracy** (正确入库的行/总行数) | missing rate (模型未填写的item/总item) | +| ----------- | ------------------------------------------- | ------------------------------------------------- | -------------------------------------- | +| kimi | | | | +| qwen | | | | +| gpt-4o | | | | +| deepseek v3 | | | | + +已收集人类标注样本共63个 + +| 数据源 | 样本数 | +| ------------------------------------------------------------ | ------ | +| [Cellxgene](https://cellxgene.cziscience.com/datasets) | 21 | +| [Human Cell Atlas (data explorer)](https://explore.data.humancellatlas.org/projects) | 21 | +| [Broad Institute - single cell portal](https://singlecell.broadinstitute.org/single_cell) | 21 | + +#### task 3 Vanna text to sql 准确率 + +设计100个检索任务(query, 如 What's the title corresponding to GSE204684?,可以把列名跟描述给模型,让模型提出query),判断返回值是否符合预期 + +| | sql 生成率 (生成sql数/query数) | sql可执行率(可成功执行sql/query数) | 
成功率(执行sql返回结果符合预期/总query) | +| ----------- | 
def read_excel_file(file_path):
    """Read an Excel sheet (first row = header) into a list of row dicts."""
    df = pd.read_excel(file_path, header=0)
    return df.to_dict(orient='records')


def _iter_batches(items, batch_size):
    """Yield ``(start_index, batch)`` for consecutive slices of *items*.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch size must be at least 1, got {batch_size}")
    for start_index in range(0, len(items), batch_size):
        yield start_index, items[start_index:start_index + batch_size]


def process_metadata(args):
    """Extract, save and store the metadata listed in ``args.input``.

    The input list is processed in batches of ``args.batch_size``. Every
    successful result is written to
    ``<output_dir>/<source>_<task number, 6 digits>.json`` and inserted into
    the database; failed tasks are logged to ``args.log_file``. Task numbers
    are 1-based positions in the input list.

    Args:
        args: Parsed CLI namespace with ``log_file``, ``source``, ``base_url``,
            ``api_key``, ``model``, ``schema``, ``database``, ``input``,
            ``batch_size``, ``workers`` and ``output_dir``.

    Raises:
        ValueError: If ``args.batch_size`` is smaller than 1.
    """
    # Setup logging; create the log directory first so a path such as
    # "logs/processing.log" works on a fresh checkout.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        filename=args.log_file,
        level=logging.ERROR,
        format='%(asctime)s:%(levelname)s:%(message)s'
    )

    # Initialize extractor and controller
    extractor = ProjectMetadataExtractor(
        args.source,
        args.base_url,  # shared --base-url option
        args.api_key,
        args.model,
        json_schema=read_excel_file(args.schema)
    )
    controller = SampleController(args.database)

    # Read input data
    logging.info('loading')
    with open(args.input, 'r', encoding='utf-8') as f:
        input_metadata_list = json.load(f)

    # Validate the batch size and slice the input up front
    batches = list(_iter_batches(input_metadata_list, args.batch_size))
    os.makedirs(args.output_dir, exist_ok=True)

    sum_token_usage = sum_input_token = sum_output_token = 0
    failed_tasks_all_batches = []

    for batch_no, (start_index, batch) in enumerate(
            tqdm(batches, desc="Processing Batches", unit="batch"), start=1):
        results, failed_tasks = extractor.extract_batch(batch, max_workers=args.workers)

        # Log failed tasks
        for task in failed_tasks:
            task_num = start_index + task + 1
            logging.error(f"Failed task {task} in batch {batch_no}: No {task_num}")
            failed_tasks_all_batches.append(task_num)

        # Process results
        for task_id, content in results:
            result_data, token_usage = extractor.post_process_data(content)

            # Update token counts
            sum_input_token += token_usage['input_tokens']
            sum_output_token += token_usage['output_tokens']
            sum_token_usage += token_usage['total_tokens']

            # Save result and update database
            task_num = start_index + task_id + 1
            result_json_path = os.path.join(args.output_dir, f"{args.source}_{task_num:06d}.json")
            with open(result_json_path, 'w', encoding='utf-8') as f:
                json.dump(result_data, f, ensure_ascii=False, indent=4)

            res = controller.insert_sample(result_data)
            # Same 1-based number as the file name and the failure log.
            print(f'Task {task_num} status: {res.get("status")}')

    # Print summary
    print("Failed tasks (original numbers):", failed_tasks_all_batches)
    print(f"Token usage - Total: {sum_token_usage}, Input: {sum_input_token}, Output: {sum_output_token}")


def retrieve_query(args):
    """Answer ``args.query`` with Vanna and print the SQL and result rows."""
    wrapper = BSMVannaWrapper(
        api_key=args.api_key,
        db_path=args.db_path,
        model=args.model,
        base_url=args.base_url
    )
    sql, df = wrapper.ask(question=args.query, table=args.table)
    print("Generated SQL Query:")
    print(sql)
    print("\nResult DataFrame:")
    # Guard against a missing result frame instead of failing on None.
    print(df.to_numpy().tolist() if df is not None else None)
common_parser.add_argument('--model', default='gpt-4o', help='Language model name') + common_parser.add_argument('--base-url', default='https://api.openai.com/v1/', + help='API base URL for the language model') + + # Create main subparsers for different modules + subparsers = parser.add_subparsers(dest='module', help='Available modules') + + # Download module + download_parser = subparsers.add_parser('download', help='Download management') + download_parser.add_argument('--type', choices=['hca', 'scp', 'cxg'], required=True, + help='Downloader type (hca, scp or cxg)') + download_parser.add_argument('--database', required=True, help='Database path') + download_parser.add_argument('--table', required=True, help='Table name') + download_parser.add_argument('--save-dir', required=True, help='Save directory') + download_parser.add_argument('--workers', type=int, default=1, help='Number of parallel downloads') + download_parser.add_argument('--timeout', type=int, default=7200, help='Download timeout in seconds') + download_parser.add_argument('--dcp', help='DCP value for HCA downloader') + download_parser.add_argument('--cookie', help='Cookie file path (JSON format) for SCP downloader') + + # Fetch module + fetch_parser = subparsers.add_parser('fetch', help='Data fetching') + fetch_parser.add_argument('--database', choices=['scp', 'hca', 'cxg'], required=True, + help='Database to fetch from (scp: Single Cell Portal, hca: Human Cell Atlas, cxg: CellxGene)') + fetch_parser.add_argument('--output', required=True, help='Output JSON file path') + fetch_parser.add_argument('--domain', help='Custom domain name (optional)') + fetch_parser.add_argument('--dcp', help='DCP server address (optional)') + + # Add metadata processing module + process_parser = subparsers.add_parser('process', help='Process metadata', parents=[common_parser]) + process_parser.add_argument('--source', required=True, choices=['scp', 'hca', 'cxg'], + help='Source database type') + 
process_parser.add_argument('--input', required=True, + help='Input JSON file containing metadata') + process_parser.add_argument('--output-dir', required=True, + help='Output directory for processed JSON files') + process_parser.add_argument('--database', required=True, + help='SQLite database path') + process_parser.add_argument('--schema', required=True, + help='JSON schema file path') + process_parser.add_argument('--batch-size', type=int, default=5, + help='Number of items to process in each batch') + process_parser.add_argument('--workers', type=int, default=5, + help='Number of parallel workers') + process_parser.add_argument('--log-file', default='process.log', + help='Log file path') + + # Add Vanna query module + vanna_parser = subparsers.add_parser('vanna', help='Query database using Vanna AI', parents=[common_parser]) + vanna_parser.add_argument('--db-path', required=True, help='SQLite database path') + vanna_parser.add_argument('--question', required=True, help='Question to ask the database') + vanna_parser.add_argument('--table', default='Sample', help='Table name to query') + + # Add retrieve module + retrieve_parser = subparsers.add_parser('retrieve', help='Retrieve query using Vanna AI', parents=[common_parser]) + retrieve_parser.add_argument('--query', required=True, help='Query question to ask the database') + retrieve_parser.add_argument('--db-path', required=True, help='SQLite database path') + retrieve_parser.add_argument('--table', default='Sample', help='Table name to query') + + args = parser.parse_args() + + try: + if args.module == 'download': + downloader_kwargs = { + 'database_path': args.database, + 'table_name': args.table, + 'save_root': args.save_dir, + 'downloader_type': args.type, + 'num_workers': args.workers, + 'timeout': args.timeout + } + + if args.type == 'hca' and args.dcp: + downloader_kwargs['dcp'] = args.dcp + elif args.type == 'scp' and args.cookie: + try: + with open(args.cookie, 'r') as f: + downloader_kwargs['cookie'] 
= json.load(f) + except (json.JSONDecodeError, FileNotFoundError) as e: + print(f"Error reading cookie file: {e}") + return 1 + + # Create and run downloader + downloader = Downloader(**downloader_kwargs) + asyncio.run(downloader.main()) + + elif args.module == 'fetch': + if args.database == 'scp': + fetcher = SingleCellPortalFetcher( + domain_name=args.domain if args.domain else "singlecell.broadinstitute.org", + ) + fetcher.fetch(args.output) + + elif args.database == 'hca': + fetcher = ExploreDataFetcher(dcp_num=args.dcp) + fetcher.fetch(args.output) + + elif args.database == 'cxg': + fetcher = CellxgeneFetcher( + domain_name=args.domain if args.domain else "cellxgene.cziscience.com/curation/v1" + ) + fetcher.fetch(args.output) + + elif args.module == 'process': + process_metadata(args) + + elif args.module == 'retrieve': + retrieve_query(args) + + else: + parser.print_help() + + except Exception as e: + print(f"Error: {str(e)}") + return 1 + + return 0 + + +if __name__ == '__main__': + exit(main()) diff --git a/examples/OSS/sync_oss.py b/examples/OSS/sync_oss.py new file mode 100644 index 0000000..1c16e8c --- /dev/null +++ b/examples/OSS/sync_oss.py @@ -0,0 +1,12 @@ +from BSM.OSSConnector.OSSConnector import ZJlabOSSConnector + +if __name__ == "__main__": + access_key ='' + secret = '' + endpoint_url = '' + connector = ZJlabOSSConnector(access_key, secret, endpoint_url) + + bucket_name = '' + local_dir = '' + remote_dir = '' + connector.sync_folder(bucket_name, local_dir, remote_dir) \ No newline at end of file diff --git a/examples/download/hca_download.py b/examples/download/hca_download.py new file mode 100644 index 0000000..03bb5b1 --- /dev/null +++ b/examples/download/hca_download.py @@ -0,0 +1,13 @@ +import asyncio +from BSM.Downloader.downloader import HCADownloader + +def start_downloading(database_path, save_root): + downloader = HCADownloader(database_path, save_root) + asyncio.run(downloader.main()) + + +if __name__ == '__main__': + database_path = 
def start_downloading(database_path, table_name, save_root):
    """Create an HCADownloader for the given arguments and run its ``main()`` via asyncio."""
    asyncio.run(HCADownloader(database_path, table_name, save_root).main())


if __name__ == '__main__':
    # Example invocation; replace the paths below with your own.
    database_path = r'E:\projects-hca-qwen2-72b-instruct1128.db'
    save_root = r'E:\backup\hca_download'
    table_name = r'Sample'

    start_downloading(database_path, table_name, save_root)