From 2afb85d13f8073112cdef45a1804c5a1f70bd417 Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Wed, 27 Nov 2024 15:45:23 +0800 Subject: [PATCH 01/12] Basedownloader --- BSM/Downloader/__init__.py | 0 BSM/Downloader/downloader.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 BSM/Downloader/__init__.py create mode 100644 BSM/Downloader/downloader.py diff --git a/BSM/Downloader/__init__.py b/BSM/Downloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py new file mode 100644 index 0000000..e9cc03a --- /dev/null +++ b/BSM/Downloader/downloader.py @@ -0,0 +1,30 @@ +import logging + +logging.basicConfig(level=logging.INFO) + + +class BaseDownloader(object): + def __init__(self, data_source): + self.logger = logging.getLogger(__name__) + self.data_source = data_source + + def execute(self, output_folder): + raise NotImplementedError("Subclasses must implement the fetch method.") + + def execute_once(self, data_, output_folder): + raise NotImplementedError("Subclasses must implement the fetch method.") + + def read_data_source(self): + return + + +class HCADownloader(BaseDownloader): + + def execute(self, output_folder): + data = self.read_data_source() + for data_ in data: + self.execute_once(data_, output_folder) + self.logger.info(f'download complete. 
File saved to {output_folder}') + + def execute_once(self, data_, output_folder): + pass From a7a14fcf47034ad0d6cc794c6998f6ba2473513e Mon Sep 17 00:00:00 2001 From: QicangQiu Date: Thu, 5 Dec 2024 16:54:32 +0800 Subject: [PATCH 02/12] downloader code init --- BSM/Downloader/downloader.py | 115 +++++++++++++++++++++++++----- examples/download/hca_download.py | 8 +++ 2 files changed, 106 insertions(+), 17 deletions(-) create mode 100644 examples/download/hca_download.py diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index e9cc03a..357e45a 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -1,30 +1,111 @@ import logging +import os +import json +import sqlite3 + +import requests +from pathlib import Path +from tqdm import tqdm +from urllib.parse import urlparse +from ftplib import FTP logging.basicConfig(level=logging.INFO) -class BaseDownloader(object): - def __init__(self, data_source): - self.logger = logging.getLogger(__name__) - self.data_source = data_source +class BaseDownloader: + """基础下载器类,包含公共下载功能""" + + def __init__(self, download_dir): + """初始化下载目录""" + self.download_dir = download_dir + + def _save_path(self, url): + """根据URL获取保存文件的路径""" + local_filename = os.path.basename(urlparse(url).path) + save_path = os.path.join(self.download_dir, local_filename) + return save_path + + def _create_directory(self, path): + """创建目录(如果不存在)""" + if not os.path.exists(path): + os.makedirs(path) - def execute(self, output_folder): - raise NotImplementedError("Subclasses must implement the fetch method.") + def _download_with_progress(self, url, file_path): + """下载文件并显示进度条""" + try: + if url.startswith('https'): + with requests.get(url, stream=True) as r: + r.raise_for_status() + total_size_in_bytes = int(r.headers.get('content-length', 0)) + block_size = 1024 # 1KB - def execute_once(self, data_, output_folder): - raise NotImplementedError("Subclasses must implement the fetch method.") + with open(file_path, 'wb') as 
f: + with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, + desc=os.path.basename(file_path)) as pbar: + for data in r.iter_content(block_size): + f.write(data) + pbar.update(len(data)) + elif url.startswith('ftp'): + def callback(block_num, block_size, total_size): + progress_bytes = block_num * block_size + tqdm.write(f"\r下载进度: {progress_bytes / total_size * 100:.2f}%", end='') - def read_data_source(self): - return + with FTP(urlparse(url).hostname) as ftp: + ftp.login() # 默认使用匿名登录,如有需要可添加用户名和密码 + file_parts = urlparse(url).path.split('/') + remote_filename = file_parts[-1] + total_size = ftp.size(remote_filename) + with open(file_path, 'wb') as f: + ftp.retrbinary(f'RETR {remote_filename}', f.write, + callback=lambda block_num, block_size: callback(block_num, block_size, + total_size)) + tqdm.write('\n') # 换行 + else: + raise ValueError("不支持的协议: " + url) + except Exception as e: + print(f"下载文件 {url} 时出错: {e}") + +# HCADownloader类继承BaseDownloader class HCADownloader(BaseDownloader): + """特定于HCA项目的下载器类,继承自BaseDownloader""" + + def __init__(self, db_path, download_dir): + """初始化数据库路径和下载目录""" + super().__init__(download_dir) + self.db_path = db_path + + def fetch_download_links(self): + """从数据库获取下载链接""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT internal_id, download_links FROM Sample") + rows = cursor.fetchall() + download_tasks = [] + + for row in rows: + download_links_str = row[1] + internal_id = str(row[0]) + if download_links_str is not None: + download_links_list = json.loads(download_links_str) + for link in download_links_list: + if "service." in link or "ftp." in link: + link = link.strip() + download_tasks.append((internal_id, link)) - def execute(self, output_folder): - data = self.read_data_source() - for data_ in data: - self.execute_once(data_, output_folder) - self.logger.info(f'download complete. 
File saved to {output_folder}') + cursor.close() + conn.close() + return download_tasks - def execute_once(self, data_, output_folder): - pass + def download_files(self): + """下载文件""" + tuples_list = self.fetch_download_links() + for id, link in tuples_list: + save_dir = os.path.join(self.download_dir, id) + self._create_directory(save_dir) + file_path = self._save_path(link) + if not os.path.exists(save_dir): + self._download_with_progress(link, file_path) + else: + print(f"文件 {file_path} 已存在,跳过下载。") \ No newline at end of file diff --git a/examples/download/hca_download.py b/examples/download/hca_download.py new file mode 100644 index 0000000..93cd6c7 --- /dev/null +++ b/examples/download/hca_download.py @@ -0,0 +1,8 @@ +from BSM.Downloader.downloader import HCADownloader + +if __name__ == "__main__": + db_path = "../../DBS/projects-hca-qwen2-72b-instruct1128.db" + download_dir = r'D:/zjlab/data/' + + downloader = HCADownloader(db_path, download_dir) + downloader.download_files() \ No newline at end of file From 59f1115c062e954def10d5947a2346058f2e7bfc Mon Sep 17 00:00:00 2001 From: QicangQiu <37721905+QicangQiu@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:00:16 +0800 Subject: [PATCH 03/12] Update hca_download.py modify new example code --- examples/download/hca_download.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/download/hca_download.py b/examples/download/hca_download.py index 93cd6c7..57d7d21 100644 --- a/examples/download/hca_download.py +++ b/examples/download/hca_download.py @@ -1,8 +1,14 @@ +import asyncio from BSM.Downloader.downloader import HCADownloader -if __name__ == "__main__": - db_path = "../../DBS/projects-hca-qwen2-72b-instruct1128.db" - download_dir = r'D:/zjlab/data/' +def start_downloading(database_path, table_name, save_root): + downloader = HCADownloader(database_path, table_name, save_root) + asyncio.run(downloader.main()) - downloader = HCADownloader(db_path, download_dir) - 
downloader.download_files() \ No newline at end of file + +if __name__ == '__main__': + database_path = r'../../DBS/projects-hca-qwen2-72b-instruct1128.db' + table_name = 'Sample' + save_root = r'D:/zjlab/data/' # 保存路径 + + start_downloading(database_path, table_name, save_root) From 9025786afb6f392984d75f7273ea34f1e1b6aa9e Mon Sep 17 00:00:00 2001 From: QicangQiu <37721905+QicangQiu@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:03:06 +0800 Subject: [PATCH 04/12] Update downloader.py async download script for https&ftp --- BSM/Downloader/downloader.py | 218 ++++++++++++++++++++--------------- 1 file changed, 122 insertions(+), 96 deletions(-) diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 357e45a..1d0b0e4 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -1,111 +1,137 @@ -import logging import os -import json +import aiohttp +from tqdm.asyncio import tqdm +import asyncio import sqlite3 - +import json +import aiofiles import requests -from pathlib import Path -from tqdm import tqdm -from urllib.parse import urlparse -from ftplib import FTP - -logging.basicConfig(level=logging.INFO) - class BaseDownloader: - """基础下载器类,包含公共下载功能""" + def __init__(self, database_path, table_name, save_root): + self.database_path = database_path + self.table_name = table_name + self.save_root = save_root - def __init__(self, download_dir): - """初始化下载目录""" - self.download_dir = download_dir + def create_connection(self): + return sqlite3.connect(self.database_path) - def _save_path(self, url): - """根据URL获取保存文件的路径""" - local_filename = os.path.basename(urlparse(url).path) - save_path = os.path.join(self.download_dir, local_filename) - return save_path + async def check_file_exists(self, file_path): + return os.path.exists(file_path) - def _create_directory(self, path): - """创建目录(如果不存在)""" - if not os.path.exists(path): - os.makedirs(path) - - def _download_with_progress(self, url, file_path): - """下载文件并显示进度条""" + def 
get_response_headers(self, url): try: - if url.startswith('https'): - with requests.get(url, stream=True) as r: - r.raise_for_status() - total_size_in_bytes = int(r.headers.get('content-length', 0)) - block_size = 1024 # 1KB - - with open(file_path, 'wb') as f: - with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, - desc=os.path.basename(file_path)) as pbar: - for data in r.iter_content(block_size): - f.write(data) - pbar.update(len(data)) - elif url.startswith('ftp'): - def callback(block_num, block_size, total_size): - progress_bytes = block_num * block_size - tqdm.write(f"\r下载进度: {progress_bytes / total_size * 100:.2f}%", end='') - - with FTP(urlparse(url).hostname) as ftp: - ftp.login() # 默认使用匿名登录,如有需要可添加用户名和密码 - file_parts = urlparse(url).path.split('/') - remote_filename = file_parts[-1] - total_size = ftp.size(remote_filename) - - with open(file_path, 'wb') as f: - ftp.retrbinary(f'RETR {remote_filename}', f.write, - callback=lambda block_num, block_size: callback(block_num, block_size, - total_size)) - - tqdm.write('\n') # 换行 - else: - raise ValueError("不支持的协议: " + url) + headers = { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + } + response = requests.get(url, headers=headers, allow_redirects=False) + return response.headers except Exception as e: - print(f"下载文件 {url} 时出错: {e}") + print(f"An error occurred: {e}") + return {} -# HCADownloader类继承BaseDownloader -class HCADownloader(BaseDownloader): - """特定于HCA项目的下载器类,继承自BaseDownloader""" + async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): + raise NotImplementedError("This method should be overridden by subclasses") - def __init__(self, db_path, download_dir): - """初始化数据库路径和下载目录""" - 
super().__init__(download_dir) - self.db_path = db_path - def fetch_download_links(self): - """从数据库获取下载链接""" - conn = sqlite3.connect(self.db_path) +class HCADownloader(BaseDownloader): + def __init__(self, database_path, table_name, save_root): + super().__init__(database_path, table_name, save_root) + + async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): + async with semaphore: # 控制并发量 + try: + if url.startswith('ftp://'): + pass # FTP链接的处理需要额外的库或方法,这里暂时不做处理 + else: + # 获取重定向后的URL + headers = self.get_response_headers(url) + url = headers.get('Location', url) + + async with session.get(url) as response: + if response.status == 200: + content_disposition = response.headers.get('Content-Disposition', '') + file_name = content_disposition.split('filename=')[-1].strip('"') or url.split('/')[-1].split('?')[0] + file_path = os.path.join(save_dir, file_name) + + if await self.check_file_exists(file_path): + print(f"File {file_name} already exists, skipping.") + if overall_progress: + overall_progress.update(1) + return + + total_size = int(response.headers.get('Content-Length', 0)) + wrote = 0 + + async with aiofiles.open(file_path, 'wb') as f: + with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, leave=False) as pbar: + async for chunk in response.content.iter_chunked(1024 * 1024): # 每次写入1MB 避免写入缓存 + await f.write(chunk) + wrote += len(chunk) + pbar.update(len(chunk)) + + if total_size != 0 and wrote != total_size: + print(f"ERROR, something went wrong downloading {file_name}") + else: + print(f"Failed to download {url}. 
Status code: {response.status}") + if overall_progress: + overall_progress.update(1) + except Exception as e: + print(f"Error downloading {url}: {e}") + if overall_progress: + overall_progress.update(1) + + async def main(self): + conn = self.create_connection() cursor = conn.cursor() - cursor.execute("SELECT internal_id, download_links FROM Sample") - rows = cursor.fetchall() - download_tasks = [] - - for row in rows: - download_links_str = row[1] - internal_id = str(row[0]) - if download_links_str is not None: - download_links_list = json.loads(download_links_str) - for link in download_links_list: - if "service." in link or "ftp." in link: - link = link.strip() - download_tasks.append((internal_id, link)) - - cursor.close() + cursor.execute(f"SELECT internal_id as id, download_links as link FROM {self.table_name} WHERE download_links IS NOT NULL limit 2") + links = cursor.fetchall() conn.close() - return download_tasks - - def download_files(self): - """下载文件""" - tuples_list = self.fetch_download_links() - for id, link in tuples_list: - save_dir = os.path.join(self.download_dir, id) - self._create_directory(save_dir) - file_path = self._save_path(link) - if not os.path.exists(save_dir): - self._download_with_progress(link, file_path) - else: - print(f"文件 {file_path} 已存在,跳过下载。") \ No newline at end of file + + semaphore = asyncio.Semaphore(5) # 控制并发数为5 + tasks = [] + total_files = 0 + + for item in links: + id, link_json = item + try: + link_data = json.loads(link_json) + if isinstance(link_data, dict): + total_files += len(link_data) + elif isinstance(link_data, list): + total_files += len(link_data) + except json.JSONDecodeError as e: + print(f"Error decoding JSON for ID {id}: {e}") + + with tqdm(total=total_files, desc="Overall Progress") as overall_progress: + async with aiohttp.ClientSession() as session: + for item in links: + id, link_json = item + try: + link_data = json.loads(link_json) + + if isinstance(link_data, dict): + for key, link in 
link_data.items(): + if isinstance(link, str) and link.startswith(('https://service', 'ftp://')): + save_dir = os.path.join(self.save_root, str(id)) + os.makedirs(save_dir, exist_ok=True) + real_link = self.get_response_headers(link).get('Location', link) + + task = self.download_file(session, real_link, save_dir, semaphore, overall_progress=overall_progress) + tasks.append(task) + elif isinstance(link_data, list): + for link in link_data: + if isinstance(link, str) and link.startswith(('https://service', 'ftp://', 'https://storage')): + save_dir = os.path.join(self.save_root, str(id)) + os.makedirs(save_dir, exist_ok=True) + + task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) + tasks.append(task) + else: + print(f"Unsupported data type for ID {id}: Expected dict or list, got {type(link_data)}") + + except json.JSONDecodeError as e: + print(f"Error decoding JSON for ID {id}: {e}") + + await asyncio.gather(*tasks) From ba6db1d7512bb002e07a5032340575e42ae7640f Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Wed, 11 Dec 2024 15:30:16 +0800 Subject: [PATCH 05/12] Basedownloader debug --- BSM/Downloader/downloader.py | 43 ++++++++++++++++--------------- examples/download/hca_download.py | 4 +-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 357e45a..1e6715d 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -13,25 +13,27 @@ class BaseDownloader: - """基础下载器类,包含公共下载功能""" + """Base class for downloader, containing common download functionalities""" def __init__(self, download_dir): - """初始化下载目录""" + """Initialize the download directory""" self.download_dir = download_dir - def _save_path(self, url): - """根据URL获取保存文件的路径""" + def _save_path(self, url, save_dir=None): + """Get the path to save the file based on URL""" local_filename = os.path.basename(urlparse(url).path) - save_path = os.path.join(self.download_dir, 
local_filename) - return save_path + if save_dir is not None: + return os.path.join(save_dir, local_filename) + else: + return os.path.join(self.download_dir, local_filename) def _create_directory(self, path): - """创建目录(如果不存在)""" + """Create directory if it does not exist""" if not os.path.exists(path): os.makedirs(path) def _download_with_progress(self, url, file_path): - """下载文件并显示进度条""" + """Download the file with a progress bar""" try: if url.startswith('https'): with requests.get(url, stream=True) as r: @@ -48,10 +50,10 @@ def _download_with_progress(self, url, file_path): elif url.startswith('ftp'): def callback(block_num, block_size, total_size): progress_bytes = block_num * block_size - tqdm.write(f"\r下载进度: {progress_bytes / total_size * 100:.2f}%", end='') + tqdm.write(f"\rDownload progress: {progress_bytes / total_size * 100:.2f}%", end='') with FTP(urlparse(url).hostname) as ftp: - ftp.login() # 默认使用匿名登录,如有需要可添加用户名和密码 + ftp.login() # Anonymous login by default, can add username and password if necessary file_parts = urlparse(url).path.split('/') remote_filename = file_parts[-1] total_size = ftp.size(remote_filename) @@ -61,23 +63,23 @@ def callback(block_num, block_size, total_size): callback=lambda block_num, block_size: callback(block_num, block_size, total_size)) - tqdm.write('\n') # 换行 + tqdm.write('\n') # New line else: - raise ValueError("不支持的协议: " + url) + raise ValueError("Unsupported protocol: " + url) except Exception as e: - print(f"下载文件 {url} 时出错: {e}") + logging.info(f"Error downloading file {url}: {e}") + -# HCADownloader类继承BaseDownloader class HCADownloader(BaseDownloader): - """特定于HCA项目的下载器类,继承自BaseDownloader""" + """Downloader class specific to the HCA project, inheriting from BaseDownloader""" def __init__(self, db_path, download_dir): - """初始化数据库路径和下载目录""" + """Initialize database path and download directory""" super().__init__(download_dir) self.db_path = db_path def fetch_download_links(self): - """从数据库获取下载链接""" + """Fetch download 
links from the database""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute("SELECT internal_id, download_links FROM Sample") @@ -99,13 +101,12 @@ def fetch_download_links(self): return download_tasks def download_files(self): - """下载文件""" tuples_list = self.fetch_download_links() for id, link in tuples_list: save_dir = os.path.join(self.download_dir, id) self._create_directory(save_dir) - file_path = self._save_path(link) - if not os.path.exists(save_dir): + file_path = self._save_path(link, save_dir) # 使用新的 save_dir 参数 + if not os.path.exists(file_path): # Check if the specific file path exists self._download_with_progress(link, file_path) else: - print(f"文件 {file_path} 已存在,跳过下载。") \ No newline at end of file + logging.info(f"File {file_path} already exists, skipping download.") \ No newline at end of file diff --git a/examples/download/hca_download.py b/examples/download/hca_download.py index 93cd6c7..36d2fb3 100644 --- a/examples/download/hca_download.py +++ b/examples/download/hca_download.py @@ -1,8 +1,8 @@ from BSM.Downloader.downloader import HCADownloader if __name__ == "__main__": - db_path = "../../DBS/projects-hca-qwen2-72b-instruct1128.db" - download_dir = r'D:/zjlab/data/' + db_path = "/home/lza/BSM/DBS/projects-hca-moonshot-v1-128k1128.db" + download_dir = r'/zjbs-data/hca/' downloader = HCADownloader(db_path, download_dir) downloader.download_files() \ No newline at end of file From c62de4063d1b26e873f7239ebd4d98e9b01a2926 Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Thu, 12 Dec 2024 14:06:45 +0800 Subject: [PATCH 06/12] merge downloader --- BSM/Downloader/downloader.py | 123 ++++++++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 3 deletions(-) diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 1e6715d..51e7efa 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -1,8 +1,11 @@ +import asyncio import logging import os import json import sqlite3 +import 
aiofiles +import aiohttp import requests from pathlib import Path from tqdm import tqdm @@ -105,8 +108,122 @@ def download_files(self): for id, link in tuples_list: save_dir = os.path.join(self.download_dir, id) self._create_directory(save_dir) - file_path = self._save_path(link, save_dir) # 使用新的 save_dir 参数 - if not os.path.exists(file_path): # Check if the specific file path exists + file_path = self._save_path(link, save_dir) + if not os.path.exists(file_path): self._download_with_progress(link, file_path) else: - logging.info(f"File {file_path} already exists, skipping download.") \ No newline at end of file + logging.info(f"File {file_path} already exists, skipping download.") + + async def _get_response_headers(self, session, url): + async with session.head(url, allow_redirects=True) as response: + return response.headers + + async def _check_file_exists(self, file_path): + return Path(file_path).exists() + + async def _async_download_file(self, session, url, save_dir, semaphore, overall_progress=None): + async with semaphore: # 控制并发量 + try: + if url.startswith('ftp://'): + pass # FTP链接的处理需要额外的库或方法,这里暂时不做处理 + else: + # 获取重定向后的URL + headers = await self._get_response_headers(session, url) + url = headers.get('Location', url) + + async with session.get(url) as response: + if response.status == 200: + content_disposition = response.headers.get('Content-Disposition', '') + file_name = content_disposition.split('filename=')[-1].strip('"') or \ + url.split('/')[-1].split('?')[0] + file_path = os.path.join(save_dir, file_name) + + total_size = int(response.headers.get('Content-Length', 0)) + + # 检查本地文件是否存在且大小是否正确 + if Path(file_path).exists(): + local_file_size = Path(file_path).stat().st_size + if local_file_size == total_size: + logging.info(f"File {file_name} already exists and is correct size, skipping.") + if overall_progress: + overall_progress.update(1) + return + else: + logging.warning(f"File {file_name} exists but size does not match, re-downloading.") + 
os.remove(file_path) # 移除不完整的文件 + + wrote = 0 + + async with aiofiles.open(file_path, 'wb') as f: + with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, leave=False) as pbar: + async for chunk in response.content.iter_chunked(1024 * 1024): # 每次写入1MB 避免写入缓存 + await f.write(chunk) + wrote += len(chunk) + pbar.update(len(chunk)) + + if total_size != 0 and wrote != total_size: + logging.error(f"ERROR, something went wrong downloading {file_name}") + if Path(file_path).exists(): + os.remove(file_path) # 下载失败移除部分下载的文件 + else: + logging.error(f"Failed to download {url}. Status code: {response.status}") + if overall_progress: + overall_progress.update(1) + except Exception as e: + logging.error(f"Error downloading {url}: {e}") + if overall_progress: + overall_progress.update(1) + + async def async_download_files(self, workers=5): + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT internal_id, download_links FROM Sample WHERE download_links IS NOT NULL") + links = cursor.fetchall() + conn.close() + + semaphore = asyncio.Semaphore(workers) + tasks = [] + total_files = 0 + + for item in links: + id, link_json = item + try: + link_data = json.loads(link_json) + if isinstance(link_data, dict): + total_files += len(link_data) + elif isinstance(link_data, list): + total_files += len(link_data) + except json.JSONDecodeError as e: + logging.error(f"Error decoding JSON for ID {id}: {e}") + with tqdm(total=total_files, desc="Overall Progress") as overall_progress: + async with aiohttp.ClientSession() as session: + for item in links: + id, link_json = item + try: + link_data = json.loads(link_json) + + if isinstance(link_data, dict): + for key, link in link_data.items(): + if isinstance(link, str) and link.startswith(('https://service', 'ftp://')): + save_dir = os.path.join(self.download_dir, str(id)) + os.makedirs(save_dir, exist_ok=True) + task = self._async_download_file(session, link, save_dir, semaphore, + 
overall_progress=overall_progress) + tasks.append(task) + elif isinstance(link_data, list): + for link in link_data: + if isinstance(link, str) and link.startswith( + ('https://service', 'ftp://', 'https://storage')): + save_dir = os.path.join(self.download_dir, str(id)) + os.makedirs(save_dir, exist_ok=True) + task = self._async_download_file(session, link, save_dir, semaphore, + overall_progress=overall_progress) + tasks.append(task) + else: + logging.error( + f"Unsupported data type for ID {id}: Expected dict or list, got {type(link_data)}") + + except json.JSONDecodeError as e: + logging.error(f"Error decoding JSON for ID {id}: {e}") + + await asyncio.gather(*tasks) \ No newline at end of file From 375a192ac19f57d9393ef299e14d994884c8ce77 Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Fri, 20 Dec 2024 15:59:43 +0800 Subject: [PATCH 07/12] merge downloader --- BSM/Downloader/downloader.py | 33 +++++++----- BSM/Fetcher/SingleCellDBs/exploredata.py | 4 +- BSM/Fetcher/SingleCellDBs/fetchers.py | 1 - BSM/OSSConnector/OSSConnector.py | 65 ++++++++++++++++++++++++ BSM/OSSConnector/__init__.py | 0 examples/OSS/sync_oss.py | 12 +++++ examples/metadata/fetch_HCA.py | 2 +- 7 files changed, 99 insertions(+), 18 deletions(-) create mode 100644 BSM/OSSConnector/OSSConnector.py create mode 100644 BSM/OSSConnector/__init__.py create mode 100644 examples/OSS/sync_oss.py diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 51e7efa..1dbed3f 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -8,6 +8,8 @@ import aiohttp import requests from pathlib import Path + +from aiohttp import ClientTimeout from tqdm import tqdm from urllib.parse import urlparse from ftplib import FTP @@ -123,13 +125,15 @@ async def _check_file_exists(self, file_path): async def _async_download_file(self, session, url, save_dir, semaphore, overall_progress=None): async with semaphore: # 控制并发量 - try: + # try: if url.startswith('ftp://'): - pass # 
FTP链接的处理需要额外的库或方法,这里暂时不做处理 - else: - # 获取重定向后的URL - headers = await self._get_response_headers(session, url) - url = headers.get('Location', url) + if overall_progress: + overall_progress.update(1) + return # FTP链接的处理需要额外的库或方法,这里暂时不做处理 + # else: + # # 获取重定向后的URL + # headers = await self._get_response_headers(session, url) + # url = headers.get('Location', url) async with session.get(url) as response: if response.status == 200: @@ -150,13 +154,13 @@ async def _async_download_file(self, session, url, save_dir, semaphore, overall_ return else: logging.warning(f"File {file_name} exists but size does not match, re-downloading.") - os.remove(file_path) # 移除不完整的文件 + os.remove(file_path) wrote = 0 async with aiofiles.open(file_path, 'wb') as f: with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, leave=False) as pbar: - async for chunk in response.content.iter_chunked(1024 * 1024): # 每次写入1MB 避免写入缓存 + async for chunk in response.content.iter_chunked(1024 * 1024): await f.write(chunk) wrote += len(chunk) pbar.update(len(chunk)) @@ -164,15 +168,15 @@ async def _async_download_file(self, session, url, save_dir, semaphore, overall_ if total_size != 0 and wrote != total_size: logging.error(f"ERROR, something went wrong downloading {file_name}") if Path(file_path).exists(): - os.remove(file_path) # 下载失败移除部分下载的文件 + os.remove(file_path) else: logging.error(f"Failed to download {url}. 
Status code: {response.status}") if overall_progress: overall_progress.update(1) - except Exception as e: - logging.error(f"Error downloading {url}: {e}") - if overall_progress: - overall_progress.update(1) + # except Exception as e: + # logging.error(f"Error downloading {url}: {e}") + # if overall_progress: + # overall_progress.update(1) async def async_download_files(self, workers=5): conn = sqlite3.connect(self.db_path) @@ -196,7 +200,8 @@ async def async_download_files(self, workers=5): except json.JSONDecodeError as e: logging.error(f"Error decoding JSON for ID {id}: {e}") with tqdm(total=total_files, desc="Overall Progress") as overall_progress: - async with aiohttp.ClientSession() as session: + timeout = ClientTimeout(total=60 * 300) + async with aiohttp.ClientSession(timeout=timeout) as session: for item in links: id, link_json = item try: diff --git a/BSM/Fetcher/SingleCellDBs/exploredata.py b/BSM/Fetcher/SingleCellDBs/exploredata.py index bd6061f..65b767f 100644 --- a/BSM/Fetcher/SingleCellDBs/exploredata.py +++ b/BSM/Fetcher/SingleCellDBs/exploredata.py @@ -48,14 +48,14 @@ def fetch_project(self): response = requests.get(url) response.raise_for_status() - # 解析JSON数据 + data = response.json() total = data['pagination']['total'] with tqdm(total=total, desc='Fetching Project Data', initial=data['pagination']['count']) as pbar: while url: hits = data.get('hits', []) self.project_meta_data.extend(hits) - # 获取下一页的URL + pagination = data.get('pagination', {}) url = pagination.get('next', None) if url: diff --git a/BSM/Fetcher/SingleCellDBs/fetchers.py b/BSM/Fetcher/SingleCellDBs/fetchers.py index 1b36584..07ec43f 100644 --- a/BSM/Fetcher/SingleCellDBs/fetchers.py +++ b/BSM/Fetcher/SingleCellDBs/fetchers.py @@ -4,7 +4,6 @@ class SingleCellDBFetcher(object): def __init__(self): - # 设置日志记录器 self.logger = logging.getLogger(__name__) def fetch(self, db_name): diff --git a/BSM/OSSConnector/OSSConnector.py b/BSM/OSSConnector/OSSConnector.py new file mode 100644 index 
0000000..38cc16c --- /dev/null +++ b/BSM/OSSConnector/OSSConnector.py @@ -0,0 +1,65 @@ +import logging +import os +import oss2 +import sys +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class ZJlabOSSConnector: + def __init__(self, access_key, secret, endpoint): + self.auth = oss2.Auth(access_key, secret) + self.endpoint = endpoint + + # Other methods remain unchanged... + def upload_file(self, bucket_name, remote_fp, local_fp, progress_callback=None): + """Upload a file to OSS with optional progress callback.""" + bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name) + with open(local_fp, 'rb') as fileobj: + # Use the put_object method with the progress_callback parameter. + result = bucket.put_object(remote_fp, fileobj, progress_callback=progress_callback) + return result + def sync_folder(self, bucket_name, local_dir, remote_dir): + """ + Sync a local directory with a directory on OSS using object_exists to check for file existence. + :param bucket_name: The name of the bucket on OSS + :param local_dir: Path to the local directory + :param remote_dir: Path to the remote directory (excluding bucket name) + """ + bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name) + + # Traverse the local directory and sync files + for root, dirs, files in os.walk(local_dir): + for file_name in files: + local_path = os.path.join(root, file_name) + relative_path = os.path.relpath(local_path, local_dir).replace('\\', '/') + remote_path = f'{remote_dir}/{relative_path}' + + try: + exist = bucket.object_exists(remote_path) + except oss2.exceptions.NoSuchKey: + exist = False + except oss2.exceptions.NoSuchBucket as e: + logging.error(f'Bucket {bucket_name} does not exist or is inaccessible: {e}') + continue + except Exception as e: + logging.debug(f'Error checking existence of {relative_path}: {e}') + exist = False + + if not exist: + logging.info(f'New file {relative_path}. 
Uploading...') + try: + # Pass the percentage function as the progress_callback. + self.upload_file(bucket_name, remote_path, local_path, progress_callback=self.percentage) + logging.info(f'Successfully uploaded {relative_path}') + except Exception as e: + logging.error(f'Failed to upload {relative_path}: {e}') + else: + logging.info(f'File {relative_path} already exists. Skipping...') + + @staticmethod + def percentage(consumed_bytes, total_bytes): + """Callback function for showing upload progress.""" + if total_bytes: + rate = int(100 * (float(consumed_bytes) / float(total_bytes))) + print('\r{0}% '.format(rate), end='') + sys.stdout.flush() diff --git a/BSM/OSSConnector/__init__.py b/BSM/OSSConnector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/OSS/sync_oss.py b/examples/OSS/sync_oss.py new file mode 100644 index 0000000..1c16e8c --- /dev/null +++ b/examples/OSS/sync_oss.py @@ -0,0 +1,12 @@ +from BSM.OSSConnector.OSSConnector import ZJlabOSSConnector + +if __name__ == "__main__": + access_key ='' + secret = '' + endpoint_url = '' + connector = ZJlabOSSConnector(access_key, secret, endpoint_url) + + bucket_name = '' + local_dir = '' + remote_dir = '' + connector.sync_folder(bucket_name, local_dir, remote_dir) \ No newline at end of file diff --git a/examples/metadata/fetch_HCA.py b/examples/metadata/fetch_HCA.py index 8f6e2e9..fd62b34 100644 --- a/examples/metadata/fetch_HCA.py +++ b/examples/metadata/fetch_HCA.py @@ -6,7 +6,7 @@ def main(): output_dir = 'D:\projects\BSM\jsons' fetcher = ExploreDataFetcher() - fetcher.fetch(os.path.join(output_dir, 'hca1127_rawjson.json')) + fetcher.fetch(os.path.join(output_dir, 'hca1217_rawjson.json')) if __name__ == '__main__': main() \ No newline at end of file From 29e34106c9b7b2d15e9603d4d01161096334fd0c Mon Sep 17 00:00:00 2001 From: QicangQiu Date: Fri, 27 Dec 2024 22:55:02 +0800 Subject: [PATCH 08/12] new function of resuming upload --- BSM/Downloader/downloader.py | 305 
++++++++++-------------- examples/download/hca_download_async.py | 22 +- examples/format_modify/rdata_h5.py | 6 + 3 files changed, 142 insertions(+), 191 deletions(-) create mode 100644 examples/format_modify/rdata_h5.py diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 1dbed3f..4b4fb23 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -1,191 +1,132 @@ -import asyncio -import logging import os -import json +import aiohttp +from tqdm.asyncio import tqdm +import asyncio import sqlite3 - +import json import aiofiles -import aiohttp import requests -from pathlib import Path -from aiohttp import ClientTimeout -from tqdm import tqdm -from urllib.parse import urlparse -from ftplib import FTP +class BaseDownloader: + def __init__(self, database_path, table_name, save_root): + self.database_path = database_path + self.table_name = table_name + self.save_root = save_root -logging.basicConfig(level=logging.INFO) + def create_connection(self): + return sqlite3.connect(self.database_path) + async def check_file_exists(self, file_path): + return os.path.exists(file_path) -class BaseDownloader: - """Base class for downloader, containing common download functionalities""" - - def __init__(self, download_dir): - """Initialize the download directory""" - self.download_dir = download_dir - - def _save_path(self, url, save_dir=None): - """Get the path to save the file based on URL""" - local_filename = os.path.basename(urlparse(url).path) - if save_dir is not None: - return os.path.join(save_dir, local_filename) - else: - return os.path.join(self.download_dir, local_filename) - - def _create_directory(self, path): - """Create directory if it does not exist""" - if not os.path.exists(path): - os.makedirs(path) - - def _download_with_progress(self, url, file_path): - """Download the file with a progress bar""" + def get_response_headers(self, url): try: - if url.startswith('https'): - with requests.get(url, stream=True) as r: - 
r.raise_for_status() - total_size_in_bytes = int(r.headers.get('content-length', 0)) - block_size = 1024 # 1KB - - with open(file_path, 'wb') as f: - with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, - desc=os.path.basename(file_path)) as pbar: - for data in r.iter_content(block_size): - f.write(data) - pbar.update(len(data)) - elif url.startswith('ftp'): - def callback(block_num, block_size, total_size): - progress_bytes = block_num * block_size - tqdm.write(f"\rDownload progress: {progress_bytes / total_size * 100:.2f}%", end='') - - with FTP(urlparse(url).hostname) as ftp: - ftp.login() # Anonymous login by default, can add username and password if necessary - file_parts = urlparse(url).path.split('/') - remote_filename = file_parts[-1] - total_size = ftp.size(remote_filename) - - with open(file_path, 'wb') as f: - ftp.retrbinary(f'RETR {remote_filename}', f.write, - callback=lambda block_num, block_size: callback(block_num, block_size, - total_size)) - - tqdm.write('\n') # New line - else: - raise ValueError("Unsupported protocol: " + url) + headers = { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + } + response = requests.get(url, headers=headers, allow_redirects=False) + return response.headers except Exception as e: - logging.info(f"Error downloading file {url}: {e}") + print(f"An error occurred: {e}") + return {} + async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): + raise NotImplementedError("This method should be overridden by subclasses") -class HCADownloader(BaseDownloader): - """Downloader class specific to the HCA project, inheriting from BaseDownloader""" - def __init__(self, db_path, download_dir): - """Initialize database path and download 
directory""" - super().__init__(download_dir) - self.db_path = db_path +class SpecialDownloader(BaseDownloader): + def __init__(self, database_path, table_name, save_root): + super().__init__(database_path, table_name, save_root) - def fetch_download_links(self): - """Fetch download links from the database""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - cursor.execute("SELECT internal_id, download_links FROM Sample") - rows = cursor.fetchall() - download_tasks = [] - - for row in rows: - download_links_str = row[1] - internal_id = str(row[0]) - if download_links_str is not None: - download_links_list = json.loads(download_links_str) - for link in download_links_list: - if "service." in link or "ftp." in link: - link = link.strip() - download_tasks.append((internal_id, link)) - - cursor.close() - conn.close() - return download_tasks - - def download_files(self): - tuples_list = self.fetch_download_links() - for id, link in tuples_list: - save_dir = os.path.join(self.download_dir, id) - self._create_directory(save_dir) - file_path = self._save_path(link, save_dir) - if not os.path.exists(file_path): - self._download_with_progress(link, file_path) - else: - logging.info(f"File {file_path} already exists, skipping download.") - - async def _get_response_headers(self, session, url): - async with session.head(url, allow_redirects=True) as response: - return response.headers - - async def _check_file_exists(self, file_path): - return Path(file_path).exists() - - async def _async_download_file(self, session, url, save_dir, semaphore, overall_progress=None): + async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): async with semaphore: # 控制并发量 - # try: - if url.startswith('ftp://'): - if overall_progress: - overall_progress.update(1) - return # FTP链接的处理需要额外的库或方法,这里暂时不做处理 - # else: - # # 获取重定向后的URL - # headers = await self._get_response_headers(session, url) - # url = headers.get('Location', url) - - async 
with session.get(url) as response: - if response.status == 200: - content_disposition = response.headers.get('Content-Disposition', '') - file_name = content_disposition.split('filename=')[-1].strip('"') or \ - url.split('/')[-1].split('?')[0] - file_path = os.path.join(save_dir, file_name) - - total_size = int(response.headers.get('Content-Length', 0)) - - # 检查本地文件是否存在且大小是否正确 - if Path(file_path).exists(): - local_file_size = Path(file_path).stat().st_size - if local_file_size == total_size: - logging.info(f"File {file_name} already exists and is correct size, skipping.") - if overall_progress: - overall_progress.update(1) - return - else: - logging.warning(f"File {file_name} exists but size does not match, re-downloading.") - os.remove(file_path) - - wrote = 0 - - async with aiofiles.open(file_path, 'wb') as f: - with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, leave=False) as pbar: - async for chunk in response.content.iter_chunked(1024 * 1024): - await f.write(chunk) - wrote += len(chunk) - pbar.update(len(chunk)) - - if total_size != 0 and wrote != total_size: - logging.error(f"ERROR, something went wrong downloading {file_name}") - if Path(file_path).exists(): - os.remove(file_path) + try: + headers = { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + } + url=self.get_response_headers(url).get('Location', url) + file_path = os.path.join(save_dir, url.split('/')[-1].split('?')[0]) + #print(url) + #print('路径:',file_path) + + if await self.check_file_exists(file_path): + wrote = os.path.getsize(file_path) + headers['Range'] = f'bytes={wrote}-' + print(f"Resuming download of {file_path} from byte {wrote}") + else: + wrote = 0 + + async with session.get(url, headers=headers) as response: + #print(url) + if 
response.status == 416: + print(f"The local file is large enough to not require downloading {file_path}") else: - logging.error(f"Failed to download {url}. Status code: {response.status}") - if overall_progress: - overall_progress.update(1) - # except Exception as e: - # logging.error(f"Error downloading {url}: {e}") - # if overall_progress: - # overall_progress.update(1) - - async def async_download_files(self, workers=5): - conn = sqlite3.connect(self.db_path) + if response.status == 206: # Partial Content + # print(response.headers) + content_range = response.headers.get('Content-Range', '') + if content_range: + total_size = int(content_range.partition('/')[-1].strip()) + else: + print("Warning: No Content-Range header found.") + total_size = None + elif response.status == 200: # OK - whole file + content_length = response.headers.get('Content-Length') + total_size = int(content_length) if content_length and content_length.isdigit() else None + + else: + print(f"Failed to download {url}. Status code: {response.status}") + if overall_progress: + overall_progress.update(1) + return + + if total_size is None: + print("Warning: Unable to determine total size. 
Downloading without progress tracking.") + pbar = tqdm(unit='B', unit_scale=True, desc=file_path, leave=False) + else: + pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=file_path, initial=wrote, + leave=False) + + async with aiofiles.open(file_path, 'ab') as f: # Open in append binary mode + async for chunk in response.content.iter_chunked(1024 * 1024): # 每次写入1MB 避免写入缓存 + await f.write(chunk) + wrote += len(chunk) + pbar.update(len(chunk)) + + pbar.close() + + if total_size is not None and wrote != total_size: + print(f"ERROR, something went wrong downloading {file_path}") + else: + print(f"Successfully downloaded {file_path}") + + if overall_progress: + overall_progress.update(1) + + + except Exception as e: + print(f"Error downloading {url}: {e}") + if overall_progress: + overall_progress.update(1) + + except Exception as e: + print(f"Error downloading {url}: {e}") + if overall_progress: + overall_progress.update(1) + + + + + + async def main(self): + conn = self.create_connection() cursor = conn.cursor() - cursor.execute("SELECT internal_id, download_links FROM Sample WHERE download_links IS NOT NULL") + cursor.execute(f"SELECT internal_id as id, download_links as link FROM {self.table_name} WHERE download_links IS NOT NULL") links = cursor.fetchall() conn.close() - semaphore = asyncio.Semaphore(workers) + semaphore = asyncio.Semaphore(5) # 控制并发数为5 tasks = [] total_files = 0 @@ -198,10 +139,10 @@ async def async_download_files(self, workers=5): elif isinstance(link_data, list): total_files += len(link_data) except json.JSONDecodeError as e: - logging.error(f"Error decoding JSON for ID {id}: {e}") + print(f"Error decoding JSON for ID {id}: {e}") + with tqdm(total=total_files, desc="Overall Progress") as overall_progress: - timeout = ClientTimeout(total=60 * 300) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession() as session: for item in links: id, link_json = item try: @@ -210,25 +151,25 @@ async def 
async_download_files(self, workers=5): if isinstance(link_data, dict): for key, link in link_data.items(): if isinstance(link, str) and link.startswith(('https://service', 'ftp://')): - save_dir = os.path.join(self.download_dir, str(id)) + save_dir = os.path.join(self.save_root, str(id)) os.makedirs(save_dir, exist_ok=True) - task = self._async_download_file(session, link, save_dir, semaphore, - overall_progress=overall_progress) + #real_link = self.get_response_headers(link).get('Location', link) + + task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) tasks.append(task) elif isinstance(link_data, list): for link in link_data: - if isinstance(link, str) and link.startswith( - ('https://service', 'ftp://', 'https://storage')): - save_dir = os.path.join(self.download_dir, str(id)) + if isinstance(link, str) and link.startswith(('https://service', 'ftp://', 'https://storage')): + save_dir = os.path.join(self.save_root, str(id)) os.makedirs(save_dir, exist_ok=True) - task = self._async_download_file(session, link, save_dir, semaphore, - overall_progress=overall_progress) + #real_link = self.get_response_headers(link).get('Location', link) + + task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) tasks.append(task) else: - logging.error( - f"Unsupported data type for ID {id}: Expected dict or list, got {type(link_data)}") + print(f"Unsupported data type for ID {id}: Expected dict or list, got {type(link_data)}") except json.JSONDecodeError as e: - logging.error(f"Error decoding JSON for ID {id}: {e}") + print(f"Error decoding JSON for ID {id}: {e}") await asyncio.gather(*tasks) \ No newline at end of file diff --git a/examples/download/hca_download_async.py b/examples/download/hca_download_async.py index 9f531d0..6485782 100644 --- a/examples/download/hca_download_async.py +++ b/examples/download/hca_download_async.py @@ -1,15 +1,19 @@ import asyncio -from BSM.Downloader.downloader 
import HCADownloader +from BSM.Downloader.downloader import SpecialDownloader -def start_downloading(database_path, table_name): - downloader = HCADownloader(database_path, table_name) - async def run_downloader(): - await downloader.async_download_files() - asyncio.run(run_downloader()) + +def start_downloading(database_path, table_name, save_root): + downloader = SpecialDownloader(database_path, table_name, save_root) + asyncio.run(downloader.main()) if __name__ == '__main__': - database_path = r'../../DBS/projects-hca-qwen2-72b-instruct1128.db' - save_root = r'D:\backup\hca_download' + # 示例调用,实际路径应根据需要替换 + + database_path = r'E:\projects-hca-qwen2-72b-instruct1128.db' + save_root = r'E:\backup\hca_download' + table_name = r'Sample' + + - start_downloading(database_path, save_root) \ No newline at end of file + start_downloading(database_path, table_name, save_root) \ No newline at end of file diff --git a/examples/format_modify/rdata_h5.py b/examples/format_modify/rdata_h5.py new file mode 100644 index 0000000..6238453 --- /dev/null +++ b/examples/format_modify/rdata_h5.py @@ -0,0 +1,6 @@ +import anndata +import pyreadr +import pandas as pd + +#等待更新 + From ebc8ffdce5ce58a24ef215ae27a78240927da841 Mon Sep 17 00:00:00 2001 From: QicangQiu Date: Fri, 27 Dec 2024 23:02:40 +0800 Subject: [PATCH 09/12] fix func name --- BSM/Downloader/downloader.py | 2 +- examples/download/hca_download_async.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 4b4fb23..245e1d3 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -35,7 +35,7 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, raise NotImplementedError("This method should be overridden by subclasses") -class SpecialDownloader(BaseDownloader): +class HCADownloader(BaseDownloader): def __init__(self, database_path, table_name, save_root): super().__init__(database_path, 
table_name, save_root) diff --git a/examples/download/hca_download_async.py b/examples/download/hca_download_async.py index 6485782..34142e2 100644 --- a/examples/download/hca_download_async.py +++ b/examples/download/hca_download_async.py @@ -1,9 +1,9 @@ import asyncio -from BSM.Downloader.downloader import SpecialDownloader +from BSM.Downloader.downloader import HCADownloader def start_downloading(database_path, table_name, save_root): - downloader = SpecialDownloader(database_path, table_name, save_root) + downloader = HCADownloader(database_path, table_name, save_root) asyncio.run(downloader.main()) From 4af52d10a7626ff5cd34a8bdfaa1755ce7c2c747 Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Fri, 14 Feb 2025 10:12:12 +0800 Subject: [PATCH 10/12] merge downloader --- BSM/Downloader/downloader.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index 245e1d3..bb12c16 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -36,20 +36,22 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, class HCADownloader(BaseDownloader): - def __init__(self, database_path, table_name, save_root): + def __init__(self, database_path, table_name, save_root, num_workers=1, dcp=None): super().__init__(database_path, table_name, save_root) + self.num_workers = num_workers + self.dcp = dcp async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): - async with semaphore: # 控制并发量 + async with semaphore: try: headers = { "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" } + if self.dcp is not None: + url = url.replace('dcp44', self.dcp) 
url=self.get_response_headers(url).get('Location', url) file_path = os.path.join(save_dir, url.split('/')[-1].split('?')[0]) - #print(url) - #print('路径:',file_path) if await self.check_file_exists(file_path): wrote = os.path.getsize(file_path) @@ -59,12 +61,10 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, wrote = 0 async with session.get(url, headers=headers) as response: - #print(url) if response.status == 416: - print(f"The local file is large enough to not require downloading {file_path}") + print(f"The local file is large enough for: {file_path}") else: if response.status == 206: # Partial Content - # print(response.headers) content_range = response.headers.get('Content-Range', '') if content_range: total_size = int(content_range.partition('/')[-1].strip()) @@ -115,10 +115,6 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, if overall_progress: overall_progress.update(1) - - - - async def main(self): conn = self.create_connection() cursor = conn.cursor() @@ -126,7 +122,7 @@ async def main(self): links = cursor.fetchall() conn.close() - semaphore = asyncio.Semaphore(5) # 控制并发数为5 + semaphore = asyncio.Semaphore(self.num_workers) tasks = [] total_files = 0 @@ -153,7 +149,6 @@ async def main(self): if isinstance(link, str) and link.startswith(('https://service', 'ftp://')): save_dir = os.path.join(self.save_root, str(id)) os.makedirs(save_dir, exist_ok=True) - #real_link = self.get_response_headers(link).get('Location', link) task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) tasks.append(task) @@ -162,7 +157,6 @@ async def main(self): if isinstance(link, str) and link.startswith(('https://service', 'ftp://', 'https://storage')): save_dir = os.path.join(self.save_root, str(id)) os.makedirs(save_dir, exist_ok=True) - #real_link = self.get_response_headers(link).get('Location', link) task = self.download_file(session, link, save_dir, semaphore, 
overall_progress=overall_progress) tasks.append(task) From d226eb8a55a0864176867d0255cb396b0a886f8a Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Thu, 10 Apr 2025 16:53:29 +0800 Subject: [PATCH 11/12] command line api update --- BSM/Downloader/__init__.py | 3 + BSM/Downloader/downloader.py | 165 ++++++++++++------ BSM/Fetcher/SingleCellDBs/cellxgene.py | 5 +- BSM/Fetcher/SingleCellDBs/exploredata.py | 12 +- .../SingleCellDBs/single_cell_portal.py | 8 +- BSM/Processors/ProjectMetadataExtractor.py | 2 +- README.md | 51 ++++++ 7 files changed, 177 insertions(+), 69 deletions(-) diff --git a/BSM/Downloader/__init__.py b/BSM/Downloader/__init__.py index e69de29..5a849c1 100644 --- a/BSM/Downloader/__init__.py +++ b/BSM/Downloader/__init__.py @@ -0,0 +1,3 @@ +from .downloader import * + +__all__ = ['HCADownloader', 'SCPDownloader'] \ No newline at end of file diff --git a/BSM/Downloader/downloader.py b/BSM/Downloader/downloader.py index bb12c16..db1e150 100644 --- a/BSM/Downloader/downloader.py +++ b/BSM/Downloader/downloader.py @@ -34,25 +34,63 @@ def get_response_headers(self, url): async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): raise NotImplementedError("This method should be overridden by subclasses") - -class HCADownloader(BaseDownloader): - def __init__(self, database_path, table_name, save_root, num_workers=1, dcp=None): +class Downloader(BaseDownloader): + def __init__(self, database_path, table_name, save_root, num_workers=1, downloader_type='hca', timeout=7200, **kwargs): super().__init__(database_path, table_name, save_root) self.num_workers = num_workers - self.dcp = dcp - + self.downloader_type = downloader_type + self.timeout = timeout + if downloader_type == 'hca': + self.dcp = kwargs.get('dcp') + self.cookie = None + elif downloader_type == 'scp': + self.cookie = self._process_cookie(kwargs.get('cookie')) + self.dcp = None + elif downloader_type == 'cxg': + self.cookie = None + self.dcp = None + 
else: + raise ValueError(f"unsupported database type: {downloader_type}") + @staticmethod + def _process_cookie(cookie): + """ + Process the cookie parameter. It can be a dictionary or a path to a JSON file. + If it's a path, load and return the dictionary. + Otherwise, raise a ValueError. + """ + if isinstance(cookie, dict): + return cookie + elif isinstance(cookie, str) and os.path.isfile(cookie): + try: + with open(cookie, 'r') as f: + return json.load(f) + except Exception as e: + raise ValueError(f"Failed to load cookie from JSON file at {cookie}: {e}") + else: + raise ValueError(f"Invalid cookie format. Expected a dictionary or a valid JSON file path, got {type(cookie)}") async def download_file(self, session, url, save_dir, semaphore, progress=None, overall_progress=None): async with semaphore: try: headers = { - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" - } - if self.dcp is not None: - url = url.replace('dcp44', self.dcp) - url=self.get_response_headers(url).get('Location', url) - file_path = os.path.join(save_dir, url.split('/')[-1].split('?')[0]) - + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + } + + if self.downloader_type == 'hca': + if self.dcp is not None: + url = url.replace('dcp44', self.dcp) + final_url = self.get_response_headers(url).get('Location', url) + file_path = os.path.join(save_dir, final_url.split('/')[-1].split('?')[0]) + elif self.downloader_type == 'scp': # scp + url = url.replace('/api/v1/site/studies', '/data/public') + final_url = url + file_path = 
os.path.join(save_dir, final_url.split('/')[-1].split('?')[-1].replace('filename=','')) + elif self.downloader_type == 'cxg': + final_url = url + file_path = os.path.join(save_dir, final_url.split('/')[-1].split('?')[-1]) + else: + final_url = url + file_path = os.path.join(save_dir, final_url.split('/')[-1].split('?')[-1]) if await self.check_file_exists(file_path): wrote = os.path.getsize(file_path) headers['Range'] = f'bytes={wrote}-' @@ -60,36 +98,33 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, else: wrote = 0 - async with session.get(url, headers=headers) as response: - if response.status == 416: - print(f"The local file is large enough for: {file_path}") + request_kwargs = {'headers': headers, 'timeout': self.timeout} + if self.downloader_type == 'scp' and self.cookie: + request_kwargs['cookies'] = self.cookie + + async with session.get(final_url, **request_kwargs) as response: + if response.status == 416: + print(f"The local copy of {file_path} is complete.") + if overall_progress: + overall_progress.update(1) else: + total_size = None if response.status == 206: # Partial Content - content_range = response.headers.get('Content-Range', '') + content_range = response.headers.get('Content-Range') if content_range: total_size = int(content_range.partition('/')[-1].strip()) - else: - print("Warning: No Content-Range header found.") - total_size = None elif response.status == 200: # OK - whole file - content_length = response.headers.get('Content-Length') - total_size = int(content_length) if content_length and content_length.isdigit() else None - - else: - print(f"Failed to download {url}. Status code: {response.status}") - if overall_progress: - overall_progress.update(1) - return + total_size = int(response.headers.get('Content-Length', 0)) if total_size is None: print("Warning: Unable to determine total size. 
Downloading without progress tracking.") - pbar = tqdm(unit='B', unit_scale=True, desc=file_path, leave=False) + pbar = tqdm(unit='B', unit_scale=True, desc=f"Downloading {file_path}", leave=False) else: - pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=file_path, initial=wrote, - leave=False) + pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {file_path}", + initial=wrote, leave=False) - async with aiofiles.open(file_path, 'ab') as f: # Open in append binary mode - async for chunk in response.content.iter_chunked(1024 * 1024): # 每次写入1MB 避免写入缓存 + async with aiofiles.open(file_path, 'ab') as f: + async for chunk in response.content.iter_chunked(1024 * 1024): await f.write(chunk) wrote += len(chunk) pbar.update(len(chunk)) @@ -97,21 +132,15 @@ async def download_file(self, session, url, save_dir, semaphore, progress=None, pbar.close() if total_size is not None and wrote != total_size: - print(f"ERROR, something went wrong downloading {file_path}") + print( + f"ERROR: Incomplete download detected for {file_path}. 
Expected {total_size} bytes, got {wrote}.") else: print(f"Successfully downloaded {file_path}") - if overall_progress: - overall_progress.update(1) - - except Exception as e: - print(f"Error downloading {url}: {e}") - if overall_progress: - overall_progress.update(1) + print(f"Error occurred while downloading {url}: {e}") - except Exception as e: - print(f"Error downloading {url}: {e}") + finally: if overall_progress: overall_progress.update(1) @@ -126,38 +155,35 @@ async def main(self): tasks = [] total_files = 0 + # 计算总文件数 for item in links: id, link_json = item try: link_data = json.loads(link_json) - if isinstance(link_data, dict): - total_files += len(link_data) - elif isinstance(link_data, list): + if isinstance(link_data, (dict, list)): total_files += len(link_data) except json.JSONDecodeError as e: print(f"Error decoding JSON for ID {id}: {e}") with tqdm(total=total_files, desc="Overall Progress") as overall_progress: async with aiohttp.ClientSession() as session: - for item in links: + selected_links = links[:] if self.downloader_type == 'scp' else links + + for item in selected_links: id, link_json = item try: link_data = json.loads(link_json) + save_dir = os.path.join(self.save_root, str(id)) + os.makedirs(save_dir, exist_ok=True) if isinstance(link_data, dict): for key, link in link_data.items(): - if isinstance(link, str) and link.startswith(('https://service', 'ftp://')): - save_dir = os.path.join(self.save_root, str(id)) - os.makedirs(save_dir, exist_ok=True) - + if isinstance(link, str) and link.startswith(('https://', 'ftp://')): task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) tasks.append(task) elif isinstance(link_data, list): for link in link_data: - if isinstance(link, str) and link.startswith(('https://service', 'ftp://', 'https://storage')): - save_dir = os.path.join(self.save_root, str(id)) - os.makedirs(save_dir, exist_ok=True) - + if isinstance(link, str) and link.startswith(('https://', 
'ftp://', 'https://storage')): task = self.download_file(session, link, save_dir, semaphore, overall_progress=overall_progress) tasks.append(task) else: @@ -166,4 +192,31 @@ async def main(self): except json.JSONDecodeError as e: print(f"Error decoding JSON for ID {id}: {e}") - await asyncio.gather(*tasks) \ No newline at end of file + await asyncio.gather(*tasks) + +if __name__ == "__main__": + # HCA下载器示例 + hca_downloader = Downloader( + database_path="path/to/your/database.db", + table_name="your_table_name", + save_root="path/to/save/directory", + downloader_type="hca", + num_workers=4, # 并发数 + dcp="your_dcp_value", # HCA特定参数 + timeout=7200 # 可选的超时设置 + ) + + # SCP下载器示例 + scp_downloader = Downloader( + database_path="path/to/your/database.db", + table_name="your_table_name", + save_root="path/to/save/directory", + downloader_type="scp", + num_workers=4, # 并发数 + cookie={"your": "cookie"}, # SCP特定参数 + timeout=7200 # 可选的超时设置 + ) + + # 运行下载器 + asyncio.run(hca_downloader.main()) # 或者 + asyncio.run(scp_downloader.main()) \ No newline at end of file diff --git a/BSM/Fetcher/SingleCellDBs/cellxgene.py b/BSM/Fetcher/SingleCellDBs/cellxgene.py index e3622e3..c4310b2 100644 --- a/BSM/Fetcher/SingleCellDBs/cellxgene.py +++ b/BSM/Fetcher/SingleCellDBs/cellxgene.py @@ -13,12 +13,14 @@ def __init__(self, domain_name="cellxgene.cziscience.com/curation/v1"): self.headers = {"Content-Type": "application/json"} def fetch_dataset(self): + self.logger.info('fetching all cellxgene datasets') res = requests.get(url=self.datasets_url, headers=self.headers) res.raise_for_status() data = res.json() return data def fetch_collections(self): + self.logger.info('fetching all cellxgene collections') res = requests.get(url=self.collections_url, headers=self.headers) res.raise_for_status() data = res.json() @@ -27,7 +29,6 @@ def fetch_collections(self): def fetch(self, db_name): collections = self.fetch_collections() datasets = self.fetch_dataset() - merged_datasets = [] for collection in 
collections: collection_datasets = collection.get('datasets', []) @@ -40,5 +41,5 @@ def fetch(self, db_name): json_manager = JsonManager(db_name) json_manager.save(merged_datasets) - self.logger.info("Data saved successfully to JSON file.") + self.logger.info(f"Data saved successfully to {db_name} file.") diff --git a/BSM/Fetcher/SingleCellDBs/exploredata.py b/BSM/Fetcher/SingleCellDBs/exploredata.py index 65b767f..95867c5 100644 --- a/BSM/Fetcher/SingleCellDBs/exploredata.py +++ b/BSM/Fetcher/SingleCellDBs/exploredata.py @@ -10,14 +10,13 @@ class ExploreDataFetcher(SingleCellDBFetcher): - def __init__(self, project_url=r'https://service.azul.data.humancellatlas.org/index/projects?size=100&catalog=dcp44&order=asc&sort=projectTitle&filters=%7B%7D', - files_url=r'https://service.azul.data.humancellatlas.org/index/files'): + def __init__(self, project_url=None, files_url=None, dcp_num='dcp44'): super().__init__() self.project_meta_data = [] self.project_meta_data_with_url = [] - self.project_url = project_url - self.files_url = files_url - self.dcp_num = "dcp44" + self.dcp_num = dcp_num + self.project_url = rf'https://service.azul.data.humancellatlas.org/index/projects?size=100&catalog={self.dcp_num}&order=asc&sort=projectTitle&filters=%7B%7D' if project_url is None else project_url + self.files_url = r'https://service.azul.data.humancellatlas.org/index/files' if files_url is None else files_url self.headers = {'Accept': 'application/json, text/plain, */*'} def fetch(self, file_name): @@ -86,9 +85,7 @@ def fetch_url(self, projects): with tqdm(total=total, desc='Fetching Project URLs', initial=data['pagination']['count']) as pbar: while url: - # 提取“hits”字段 hits = data.get('hits', []) - # 处理每个hit的files字段 for hit in hits: files = hit.get('files', []) for file in files: @@ -98,7 +95,6 @@ def fetch_url(self, projects): file.pop('format') aggregated_data[file_format].append(file) - # 获取下一页的URL pagination = data.get('pagination', {}) url = pagination.get('next', None) if 
url: diff --git a/BSM/Fetcher/SingleCellDBs/single_cell_portal.py b/BSM/Fetcher/SingleCellDBs/single_cell_portal.py index c87a56c..49f2ec6 100644 --- a/BSM/Fetcher/SingleCellDBs/single_cell_portal.py +++ b/BSM/Fetcher/SingleCellDBs/single_cell_portal.py @@ -1,4 +1,6 @@ import requests +from tqdm import tqdm + from BSM.Fetcher.SingleCellDBs.fetchers import SingleCellDBFetcher from BSM.Fetcher.utils import JsonManager @@ -15,18 +17,20 @@ def fetch(self, db_name): if response.status_code == 200: studies = response.json() final_data = [] - for study in studies: + for study in tqdm(studies): accessions = study.get('accession', 'N/A') study_url = f"{self.datasets_url}/{accessions}" response = requests.get(study_url, headers=self.headers, verify=False) if response.status_code == 200: study_data = response.json() final_data.append(study_data) - self.logger.info(f"Data saved successfully to {accessions}.json file.") + # self.logger.info() + tqdm.write(f"Data saved successfully to {accessions}.json file.") else: self.logger.error(f"Failed to retrieve study {accessions}. Status code: {response.status_code}") manager = JsonManager(db_name) manager.save(final_data) + self.logger.info(f"Data saved successfully to {db_name} file.") else: self.logger.error(f"Failed to retrieve studies. Status code: {response.status_code}") diff --git a/BSM/Processors/ProjectMetadataExtractor.py b/BSM/Processors/ProjectMetadataExtractor.py index 34d3041..4e3b31b 100644 --- a/BSM/Processors/ProjectMetadataExtractor.py +++ b/BSM/Processors/ProjectMetadataExtractor.py @@ -288,5 +288,5 @@ def special_prompt(): desc_normal = f"" desc_cxg = f"Let's start with the basic information about the input data, which contains metadata about 1 project, corresponding to 1 specified doi, with 1 or more datasets in the project. 
The 'geo_id' information can be found from the 'link_name' of the 'links' in the 'datasets', note that only the id of the geo is needed for 'geo_id' field, the data of dataset_id should be put into 'other_ids' fields.." desc_hca = f"Let's start with the basic information about the input data, which contains metadata about 1 project, corresponding to 1 or more doi." - desc_scp = f"Let's start with the basic information about the input data, which contains metadata about 1 study. IDs starts with 'SCP' should be put into 'other_ids' fields. If the 'name' field contains content, it should be treated as the title of the project." + desc_scp = f"Let's start with the basic information about the input data, which contains metadata about 1 study. IDs starts with 'SCP' should be put into 'other_ids' fields. If the 'name' field contains content, it should be treated as the title of the project. The download_links appear in value of the key 'download_url' and 'media_url'" return {"normal": desc_normal, "cxg": desc_cxg, "hca": desc_hca, "scp": desc_scp} diff --git a/README.md b/README.md index d567f1f..f28bf4a 100644 --- a/README.md +++ b/README.md @@ -10,4 +10,55 @@ Fetch, process and manage metadata and data samples for following databases: - [Broad Institue - single cell portal](https://singlecell.broadinstitute.org/single_cell) +## Fetchers +```angular2html +# Fetch from Single Cell Portal +python cli.py fetch --database scp --output scp_data.json + +# Fetch from Human Cell Atlas +python cli.py fetch --database hca --output hca_data.json + +# Fetch from CellxGene +python cli.py fetch --database cxg --output cxg_data.json + +``` + +## Processors +```angular2html +python cli.py process \ + --source scp \ + --input scp_data.json \ + --output-dir output/processed \ + --database processed_data.db \ + --schema DBS/json_schema.xlsx \ + --api-url your-api-url \ + --api-key your-api-key \ + --model gpt-4o + +# Advanced usage with custom parameters +python cli.py process \ + --source 
hca \ + --input data/hca_metadata.json \ + --output-dir output/hca_processed \ + --database projects.db \ + --schema custom_schema.json \ + --api-url "https://custom-api.example.com/v1/" \ + --api-key "your-api-key" \ + --model "custom-model" \ + --batch-size 10 \ + --workers 8 \ + --log-file logs/processing.log +``` +## Downloaders +``` +python cli.py download \ + --type scp \ + --database path/to/database.db \ + --table your_table \ + --save-dir test_downloader \ + --workers 1 \ + --timeout 7200 \ + --cookie path/to/cookie.json +``` + From 8f5bf0d34ab2dc5f809e5af5de573677323f117d Mon Sep 17 00:00:00 2001 From: Relaxxxxx Date: Fri, 30 May 2025 16:40:16 +0800 Subject: [PATCH 12/12] command line api update --- BSM/Retriever/__init__.py | 0 BSM/Retriever/open_ai_chat_customized.py | 130 +++++++++++++ BSM/Retriever/vanna_backend.py | 121 ++++++++++++ README.md | 81 +++++++- cli.py | 232 +++++++++++++++++++++++ requirements.txt | 4 +- 6 files changed, 563 insertions(+), 5 deletions(-) create mode 100644 BSM/Retriever/__init__.py create mode 100644 BSM/Retriever/open_ai_chat_customized.py create mode 100644 BSM/Retriever/vanna_backend.py create mode 100644 cli.py diff --git a/BSM/Retriever/__init__.py b/BSM/Retriever/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/BSM/Retriever/open_ai_chat_customized.py b/BSM/Retriever/open_ai_chat_customized.py new file mode 100644 index 0000000..6d87e62 --- /dev/null +++ b/BSM/Retriever/open_ai_chat_customized.py @@ -0,0 +1,130 @@ +import os + +from openai import OpenAI +from vanna.base import VannaBase + + +class OpenAI_Chat(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + # default parameters - can be overrided using config + self.temperature = 0.7 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "api_type" in config: + raise Exception( + "Passing api_type is now deprecated. Please pass an OpenAI client instead." 
+ ) + + if "api_base" in config: + raise Exception( + "Passing api_base is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_version" in config: + raise Exception( + "Passing api_version is now deprecated. Please pass an OpenAI client instead." + ) + + if client is not None: + self.client = client + return + + if config is None and client is None: + self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return + + if "api_key" in config: + self.client = OpenAI(api_key=config["api_key"]) + + if "base_url" in config: + self.client.base_url = config["base_url"] + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + if kwargs.get("model", None) is not None: + model = kwargs.get("model", None) + print( + f"Using model {model} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=model, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif kwargs.get("engine", None) is not None: + engine = kwargs.get("engine", None) + print( + f"Using model {engine} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=engine, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "engine" in self.config: + print( + f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" + ) + 
response = self.client.chat.completions.create( + engine=self.config["engine"], + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "model" in self.config: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=self.config["model"], + messages=prompt, + stop=None, + temperature=self.temperature, + ) + else: + if num_tokens > 3500: + model = "gpt-3.5-turbo-16k" + else: + model = "gpt-3.5-turbo" + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.chat.completions.create( + model=model, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + + # Find the first response from the chatbot that has text in it (some responses may not have text) + for choice in response.choices: + if "text" in choice: + return choice.text + + # If no response with text is found, return the first response's content (which may be empty) + return response.choices[0].message.content diff --git a/BSM/Retriever/vanna_backend.py b/BSM/Retriever/vanna_backend.py new file mode 100644 index 0000000..4eed966 --- /dev/null +++ b/BSM/Retriever/vanna_backend.py @@ -0,0 +1,121 @@ +from vanna.chromadb import ChromaDB_VectorStore +import pandas as pd + +from BSM.Retriever.open_ai_chat_customized import OpenAI_Chat + +class BSMVanna(ChromaDB_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + ChromaDB_VectorStore.__init__(self, config=config) + + OpenAI_Chat.__init__(self, config=config) + + def convert_sqlite_df_to_standard(self,df, table_name): + """ + Converts a SQLite PRAGMA table_info DataFrame to a standard format compatible with INFORMATION_SCHEMA.COLUMNS. + + Args: + df (pd.DataFrame): The input DataFrame from PRAGMA table_info. + table_name (str): The name of the table. + + Returns: + pd.DataFrame: A DataFrame with columns similar to INFORMATION_SCHEMA.COLUMNS. 
+ """ + # Map SQLite column names to standard column names + mapping = { + "cid": "ordinal_position", + "name": "column_name", + "type": "data_type", + "notnull": "is_nullable", + "dflt_value": "column_default", + "pk": "primary_key" + } + + # Rename columns based on the mapping + df = df.rename(columns=mapping) + + # Add missing columns required by the original function + df["table_catalog"] = "main" # SQLite doesn't have a catalog concept; use "main" + df["table_schema"] = "main" # SQLite doesn't have schemas; use "main" + df["table_name"] = table_name + df["is_nullable"] = df["is_nullable"].apply(lambda x: "NO" if x else "YES") + df["comment"] = None # SQLite doesn't support comments on columns + + # Reorder columns to match the expected format + standard_columns = [ + "table_catalog", "table_schema", "table_name", + "column_name", "data_type", "is_nullable", + "column_default", "ordinal_position", "comment" + ] + + return df[standard_columns] + +class BSMVannaWrapper: + def __init__(self, api_key: str, db_path: str, model='gpt-4o', base_url='https://api.openai.com/v1/'): + """ + Initialize the wrapper class for MyVanna. + + Args: + api_key (str): The API key for OpenAI. + db_path (str): The path to the SQLite database file. + """ + # Initialize configuration + self.config = { + 'api_key': api_key, + 'model': model, + 'base_url': base_url + } + + # Initialize MyVanna instance + self.vn = BSMVanna(config=self.config) + self.vn.connect_to_sqlite(db_path) + + def train(self, table_name: str): + """ + Train the model on the schema of a specific table using the training plan. + + Args: + table_name (str): The name of the table to train on. 
+ """ + # Get the table information schema + df_information_schema = self.vn.run_sql(f"PRAGMA table_info({table_name})") + df_information_schema = self.vn.convert_sqlite_df_to_standard(df_information_schema, table_name) + + # Generate and execute the training plan + plan = self.vn.get_training_plan_generic(df_information_schema) + self.vn.train(plan=plan) + + def ask(self, question: str, table='Sample') -> tuple: + """ + Ask a question and get the SQL query and result DataFrame. + + Args: + question (str): The user's question. + + Returns: + tuple: A tuple containing the SQL query (str) and the result DataFrame (pd.DataFrame). + """ + self.train(table) + sql, df, fig = self.vn.ask(question=question, visualize=False, allow_llm_to_see_data=True) + return sql, df + +# if __name__ == "__main__": +# # Initialize the wrapper class with API parameters and database path +# wrapper = BSMVannaWrapper( +# api_key='sk-jxxxxxxxxxxxxxxxxxxxxxx', +# db_path=r'path/to/your/xx.db' +# ) +# +# # Train on a specific table +# wrapper.train('Sample') +# +# # Ask a question +# sql, df = wrapper.ask( +# question="What's the internal_id corresponding to GSE204684? " +# "The column geo_ids may contain more than one ID." 
+# ) +# +# # Print results +# print("Generated SQL Query:") +# print(sql) +# print("\nResult DataFrame:") +# print(df) \ No newline at end of file diff --git a/README.md b/README.md index f28bf4a..7ac3cbf 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,18 @@ Fetch, process and manage metadata and data samples for following databases: - [Broad Institue - single cell portal](https://singlecell.broadinstitute.org/single_cell) -## Fetchers +singlecelldb + +## Installation + +``` +pip install -r requirements.txt +``` + +## Usage + +### Fetchers -- Fetch metadata + ```angular2html # Fetch from Single Cell Portal python cli.py fetch --database scp --output scp_data.json @@ -23,7 +34,7 @@ python cli.py fetch --database cxg --output cxg_data.json ``` -## Processors +### Processors -- Alignment ```angular2html python cli.py process \ --source scp \ @@ -49,7 +60,7 @@ python cli.py process \ --workers 8 \ --log-file logs/processing.log ``` -## Downloaders +### Downloaders -- Download samples ``` python cli.py download \ --type scp \ @@ -61,4 +72,66 @@ python cli.py download \ --cookie path/to/cookie.json ``` - + ### Vanna -- Text to SQL + +``` +python cli.py retrieve \ + --query "What's the title corresponding to GSE204684? The column geo_ids may contain more than one ID."
\ + --api-key "your_api_key" \ + --model "gpt-4o" \ + --db-path "path/to/your/xx.db" \ + --table "Sample" +``` + + + +## Evaluation + +#### task 1 Data entry accuracy 数据入库准确率 + +对给定的json格式存储的meta数据进行清洗存入数据库,再读取,恢复为原 json,对指定的 json key进行比较,如一致则为成功 + +| 指标及说明 | 正确率 - accuracy (入库正确的样本/成功入库样本数) | 失败率 - failure rate (入库失败,如json格式解析错误,不合法字符串引起的入库失败样本/总样本数) | +| ----------- | ------------------------------------------------- | ------------------------------------------------------------ | +| kimi | | | +| qwen | | | +| gpt4-o | | | +| deepseek v3 | | | + +数据源:不需要额外标注数据源,可直接用fetcher获得的数据测试 + +| 数据源 | 样本数 | +| ------------------------------------------------------------ | ------ | +| [Cellxgene](https://cellxgene.cziscience.com/datasets) | 100 | +| [Human Cell Atlas (data explorer)](https://explore.data.humancellatlas.org/projects) | 100 | +| [Broad Institute - single cell portal](https://singlecell.broadinstitute.org/single_cell) | 100 | + +#### task 2 Data Cleaning Quality Assessment + +对齐后的样本与人工标注样本比较 + +| 指标及说明 | Field-level accuracy (入库正确的item/总item) | **Overall Sample Accuracy** (正确入库的行/总行数) | missing rate (模型未填写的item/总item) | +| ----------- | ------------------------------------------- | ------------------------------------------------- | -------------------------------------- | +| kimi | | | | +| qwen | | | | +| gpt4-o | | | | +| deepseek v3 | | | | + +已收集人类标注样本共63个 + +| 数据源 | 样本数 | +| ------------------------------------------------------------ | ------ | +| [Cellxgene](https://cellxgene.cziscience.com/datasets) | 21 | +| [Human Cell Atlas (data explorer)](https://explore.data.humancellatlas.org/projects) | 21 | +| [Broad Institute - single cell portal](https://singlecell.broadinstitute.org/single_cell) | 21 | + +#### task 3 Vanna text to sql 准确率 + +设计100个检索任务(query, 如 What's the title corresponding to GSE204684?,可以把列名跟描述给模型,让模型提出query),判断返回值是否符合预期 + +| | sql 生成率 (生成sql数/query数) | sql可执行率(可成功执行sql/query数) | 成功率(执行sql返回结果符合预期/总query) | +| ----------- |
------------------------------ | ------------------------------------ | ---------------------------------------- | +| kimi | | | | +| qwen | | | | +| gpt4-o | | | | +| deepseek v3 | | | | diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..13d7f23 --- /dev/null +++ b/cli.py @@ -0,0 +1,232 @@ +import argparse +import asyncio +import json +import logging +import math +import os + +from tqdm import tqdm + +from BSM.Fetcher.SingleCellDBs import SingleCellPortalFetcher +from BSM.Fetcher.SingleCellDBs import ExploreDataFetcher +from BSM.Fetcher.SingleCellDBs import CellxgeneFetcher +from BSM.Downloader.downloader import Downloader +from BSM.DataController.data_controller import SampleController +from BSM.Processors.ProjectMetadataExtractor import ProjectMetadataExtractor, source_info +import pandas as pd +from BSM.Retriever.vanna_backend import BSMVannaWrapper + + +def read_excel_file(file_path): + df = pd.read_excel(file_path, header=0) + data = df.to_dict(orient='records') + return data + + +def process_metadata(args): + # Setup logging + logging.basicConfig( + filename=args.log_file, + level=logging.ERROR, + format='%(asctime)s:%(levelname)s:%(message)s' + ) + + # Initialize extractor and controller + extractor = ProjectMetadataExtractor( + args.source, + args.base_url, # 使用统一的 base-url + args.api_key, + args.model, + json_schema=read_excel_file(args.schema) + ) + controller = SampleController(args.database) + + # Read input data + logging.info('loading') + with open(args.input, 'r', encoding='utf-8') as f: + input_metadata_list = json.load(f) + + # Process in batches + batch_size = args.batch_size + num_batches = math.ceil(len(input_metadata_list) / batch_size) + sum_token_usage = sum_input_token = sum_output_token = 0 + failed_tasks_all_batches = [] + + for i in tqdm(range(num_batches), desc="Processing Batches", unit="batch"): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(input_metadata_list)) + batch = 
input_metadata_list[start_index:end_index] + + results, failed_tasks = extractor.extract_batch(batch, max_workers=args.workers) + + # Log failed tasks + for task in failed_tasks: + task_num = batch_size * i + task + 1 + logging.error(f"Failed task {task} in batch {i + 1}: No {task_num}") + failed_tasks_all_batches.append(task_num) + os.makedirs(args.output_dir, exist_ok=True) + # Process results + for j, result in enumerate(results): + task_id, content = result + result_data, token_usage = extractor.post_process_data(content) + + # Update token counts + sum_input_token += token_usage['input_tokens'] + sum_output_token += token_usage['output_tokens'] + sum_token_usage += token_usage['total_tokens'] + + # Save result and update database + original_task_id = start_index + task_id + result_json_path = f"{args.output_dir}/{args.source}_{original_task_id + 1:06d}.json" + with open(result_json_path, 'w', encoding='utf-8') as f: + json.dump(result_data, f, ensure_ascii=False, indent=4) + + res = controller.insert_sample(result_data) + print(f'Task {original_task_id} status: {res.get("status")}') + + # Print summary + print("Failed tasks (original numbers):", failed_tasks_all_batches) + print(f"Token usage - Total: {sum_token_usage}, Input: {sum_input_token}, Output: {sum_output_token}") + + +def retrieve_query(args): + wrapper = BSMVannaWrapper( + api_key=args.api_key, + db_path=args.db_path, + model=args.model, + base_url=args.base_url + ) + sql, df = wrapper.ask(question=args.query, table=args.table) + print("Generated SQL Query:") + print(sql) + print("\nResult DataFrame:") + print(df.to_numpy().tolist()) + + +def main(): + parser = argparse.ArgumentParser(description='Single Cell Data Management Tool') + parser.add_argument('--version', action='version', version='Data Management CLI 1.0.0') + + # 创建通用参数组 + common_parser = argparse.ArgumentParser(add_help=False) + common_parser.add_argument('--api-key', required=True, help='API key for the language model') + 
common_parser.add_argument('--model', default='gpt-4o', help='Language model name') + common_parser.add_argument('--base-url', default='https://api.openai.com/v1/', + help='API base URL for the language model') + + # Create main subparsers for different modules + subparsers = parser.add_subparsers(dest='module', help='Available modules') + + # Download module + download_parser = subparsers.add_parser('download', help='Download management') + download_parser.add_argument('--type', choices=['hca', 'scp', 'cxg'], required=True, + help='Downloader type (hca, scp or cxg)') + download_parser.add_argument('--database', required=True, help='Database path') + download_parser.add_argument('--table', required=True, help='Table name') + download_parser.add_argument('--save-dir', required=True, help='Save directory') + download_parser.add_argument('--workers', type=int, default=1, help='Number of parallel downloads') + download_parser.add_argument('--timeout', type=int, default=7200, help='Download timeout in seconds') + download_parser.add_argument('--dcp', help='DCP value for HCA downloader') + download_parser.add_argument('--cookie', help='Cookie file path (JSON format) for SCP downloader') + + # Fetch module + fetch_parser = subparsers.add_parser('fetch', help='Data fetching') + fetch_parser.add_argument('--database', choices=['scp', 'hca', 'cxg'], required=True, + help='Database to fetch from (scp: Single Cell Portal, hca: Human Cell Atlas, cxg: CellxGene)') + fetch_parser.add_argument('--output', required=True, help='Output JSON file path') + fetch_parser.add_argument('--domain', help='Custom domain name (optional)') + fetch_parser.add_argument('--dcp', help='DCP server address (optional)') + + # Add metadata processing module + process_parser = subparsers.add_parser('process', help='Process metadata', parents=[common_parser]) + process_parser.add_argument('--source', required=True, choices=['scp', 'hca', 'cxg'], + help='Source database type') + 
process_parser.add_argument('--input', required=True, + help='Input JSON file containing metadata') + process_parser.add_argument('--output-dir', required=True, + help='Output directory for processed JSON files') + process_parser.add_argument('--database', required=True, + help='SQLite database path') + process_parser.add_argument('--schema', required=True, + help='JSON schema file path') + process_parser.add_argument('--batch-size', type=int, default=5, + help='Number of items to process in each batch') + process_parser.add_argument('--workers', type=int, default=5, + help='Number of parallel workers') + process_parser.add_argument('--log-file', default='process.log', + help='Log file path') + + # Add Vanna query module + vanna_parser = subparsers.add_parser('vanna', help='Query database using Vanna AI', parents=[common_parser]) + vanna_parser.add_argument('--db-path', required=True, help='SQLite database path') + vanna_parser.add_argument('--question', required=True, help='Question to ask the database') + vanna_parser.add_argument('--table', default='Sample', help='Table name to query') + + # Add retrieve module + retrieve_parser = subparsers.add_parser('retrieve', help='Retrieve query using Vanna AI', parents=[common_parser]) + retrieve_parser.add_argument('--query', required=True, help='Query question to ask the database') + retrieve_parser.add_argument('--db-path', required=True, help='SQLite database path') + retrieve_parser.add_argument('--table', default='Sample', help='Table name to query') + + args = parser.parse_args() + + try: + if args.module == 'download': + downloader_kwargs = { + 'database_path': args.database, + 'table_name': args.table, + 'save_root': args.save_dir, + 'downloader_type': args.type, + 'num_workers': args.workers, + 'timeout': args.timeout + } + + if args.type == 'hca' and args.dcp: + downloader_kwargs['dcp'] = args.dcp + elif args.type == 'scp' and args.cookie: + try: + with open(args.cookie, 'r') as f: + downloader_kwargs['cookie'] 
= json.load(f) + except (json.JSONDecodeError, FileNotFoundError) as e: + print(f"Error reading cookie file: {e}") + return 1 + + # Create and run downloader + downloader = Downloader(**downloader_kwargs) + asyncio.run(downloader.main()) + + elif args.module == 'fetch': + if args.database == 'scp': + fetcher = SingleCellPortalFetcher( + domain_name=args.domain if args.domain else "singlecell.broadinstitute.org", + ) + fetcher.fetch(args.output) + + elif args.database == 'hca': + fetcher = ExploreDataFetcher(dcp_num=args.dcp) + fetcher.fetch(args.output) + + elif args.database == 'cxg': + fetcher = CellxgeneFetcher( + domain_name=args.domain if args.domain else "cellxgene.cziscience.com/curation/v1" + ) + fetcher.fetch(args.output) + + elif args.module == 'process': + process_metadata(args) + + elif args.module == 'retrieve': + retrieve_query(args) + + else: + parser.print_help() + + except Exception as e: + print(f"Error: {str(e)}") + return 1 + + return 0 + + +if __name__ == '__main__': + exit(main()) diff --git a/requirements.txt b/requirements.txt index ba87509..7596c5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ GEOparse pysradb bs4 ijson -aiofiles \ No newline at end of file +aiofiles +vanna +vanna[chromadb,openai,postgres] \ No newline at end of file