diff --git a/app/api/api_router.py b/app/api/api_router.py index da2d854..f1c7c0f 100644 --- a/app/api/api_router.py +++ b/app/api/api_router.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from app.api import annotation_api, api_key_api, chat_tab_api, driver_api, test_api, user_db_api +from app.api import annotation_api, api_key_api, chat_tab_api, driver_api, query_api, test_api, user_db_api api_router = APIRouter() @@ -15,3 +15,4 @@ api_router.include_router(api_key_api.router, prefix="/keys", tags=["API Key"]) api_router.include_router(chat_tab_api.router, prefix="/chats", tags=["AI Chat"]) api_router.include_router(annotation_api.router, prefix="/annotations", tags=["Annotation"]) +api_router.include_router(query_api.router, prefix="/query", tags=["query"]) diff --git a/app/api/query_api.py b/app/api/query_api.py new file mode 100644 index 0000000..5b9927b --- /dev/null +++ b/app/api/query_api.py @@ -0,0 +1,70 @@ +# app/api/query_api.py + + +from fastapi import APIRouter, Depends + +from app.core.exceptions import APIException +from app.core.response import ResponseMessage +from app.schemas.query.query_model import QueryInfo, RequestExecutionQuery +from app.services.query_service import QueryService, query_service +from app.services.user_db_service import UserDbService, user_db_service + +query_service_dependency = Depends(lambda: query_service) +user_db_service_dependency = Depends(lambda: user_db_service) + +router = APIRouter() + + +@router.post( + "/execute", + response_model=ResponseMessage[dict | str | None], + summary="쿼리 실행", +) +def execution( + query_info: RequestExecutionQuery, + service: QueryService = query_service_dependency, + userDbservice: UserDbService = user_db_service_dependency, +) -> ResponseMessage[dict | str | None]: + + db_info = userDbservice.find_profile(query_info.user_db_id) + result = service.execution(query_info, db_info) + + if not result.is_successful: + raise APIException(result.code) + return ResponseMessage.success(value=result.data, code=result.code) + + +@router.post( + "/execute/test", + response_model=ResponseMessage[bool], + summary="쿼리 실행", +) +def execution_test( + query_info: QueryInfo, + service: QueryService = query_service_dependency, + userDbservice: UserDbService = user_db_service_dependency, +) -> ResponseMessage[bool]: + + db_info = userDbservice.find_profile(query_info.user_db_id) + result = service.execution_test(query_info, db_info) + + if not result.is_successful: + raise APIException(result.code) + return ResponseMessage.success(value=result.data, code=result.code) + + +@router.get( + "/find/{chat_tab_id}", + response_model=ResponseMessage[dict], + summary="쿼리 실행 내역 조회", +) +def find_query_history( + chat_tab_id: str, + service: QueryService = query_service_dependency, +) -> ResponseMessage[dict]: + + result = service.find_query_history(chat_tab_id) + + if not result.is_successful: + raise APIException(result.code) + return ResponseMessage.success(value=result.data, code=result.code) diff --git a/app/core/enum/db_key_prefix_name.py b/app/core/enum/db_key_prefix_name.py index 0e3fa18..819edee 100644 --- a/app/core/enum/db_key_prefix_name.py +++ b/app/core/enum/db_key_prefix_name.py @@ -9,6 +9,7 @@ class DBSaveIdEnum(Enum): driver = "DRIVER" api_key = "API-KEY" chat_tab = "CHAT_TAB" + query = "QUERY" database_annotation = "DB-ANNO" table_annotation = "TBL-ANNO" diff --git a/app/core/status.py b/app/core/status.py index e1574dd..86e73d9 100644 --- a/app/core/status.py +++ b/app/core/status.py @@ -46,6 +46,9 @@ class CommonCode(Enum): SUCCESS_DELETE_ANNOTATION = (status.HTTP_200_OK, "2402", "어노테이션을 성공적으로 삭제하였습니다.") """ SQL 성공 코드 - 25xx """ + SUCCESS_EXECUTION = (status.HTTP_201_CREATED, "2400", "쿼리를 성공적으로 수행하였습니다.") + SUCCESS_FIND_QUERY_HISTORY = (status.HTTP_200_OK, "2102", "쿼리 이력 조회를 성공하였습니다.") + SUCCESS_EXECUTION_TEST = (status.HTTP_201_CREATED, "2400", "쿼리 TEST를 성공적으로 수행하였습니다.") # ======================================= # 클라이언트 에러 (Client Error) - 4xxx @@ -90,6 +93,8 @@ class CommonCode(Enum): NO_ANNOTATION_FOR_PROFILE = (status.HTTP_404_NOT_FOUND, "4401", "해당 DB 프로필에 연결된 어노테이션이 없습니다.") """ SQL 클라이언트 에러 코드 - 45xx """ + NO_CHAT_KEY = (status.HTTP_400_BAD_REQUEST, "4501", "CHAT 키는 필수 값입니다.") + NO_QUERY = (status.HTTP_400_BAD_REQUEST, "4500", "쿼리는 필수 값입니다.") # ================================== # 서버 에러 (Server Error) - 5xx @@ -139,6 +144,7 @@ class CommonCode(Enum): ) """ SQL 서버 에러 코드 - 55xx """ + FAIL_CREATE_QUERY = (status.HTTP_500_INTERNAL_SERVER_ERROR, "5170", "쿼리 실행 정보 저장 중 에러가 발생했습니다.") def __init__(self, http_status: int, code: str, message: str): """Enum 멤버가 생성될 때 각 값을 속성으로 할당합니다.""" diff --git a/app/db/init_db.py b/app/db/init_db.py index b90cf66..c514815 100644 --- a/app/db/init_db.py +++ b/app/db/init_db.py @@ -170,13 +170,17 @@ def initialize_database(): # --- query_history 테이블 처리 --- query_history_cols = { "id": "VARCHAR(64) PRIMARY KEY NOT NULL", + "user_db_id": "VARCHAR(64) NOT NULL", "chat_message_id": "VARCHAR(64) NOT NULL", + "database": "VARCHAR(256) NOT NULL", "query_text": "TEXT NOT NULL", - "is_success": "VARCHAR(1) NOT NULL", - "error_message": "TEXT NOT NULL", + "type": "VARCHAR(32)", + "is_success": "VARCHAR(1)", + "error_message": "TEXT", "created_at": "DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", "updated_at": "DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", "FOREIGN KEY (chat_message_id)": "REFERENCES chat_message(id) ON DELETE CASCADE", + "FOREIGN KEY (user_db_id)": "REFERENCES db_profile(id) ON DELETE CASCADE", } create_sql = f"CREATE TABLE IF NOT EXISTS query_history ({', '.join([f'{k} {v}' for k, v in query_history_cols.items()])})" cursor.execute(create_sql) diff --git a/app/repository/query_repository.py b/app/repository/query_repository.py new file mode 100644 index 0000000..9dda359 --- /dev/null +++ b/app/repository/query_repository.py @@ -0,0 +1,171 @@ +import sqlite3 +from typing import Any + +import oracledb + +from app.core.exceptions import APIException +from app.core.status import CommonCode +from app.core.utils import get_db_path +from app.schemas.query.result_model import ( + BasicResult, + ExecutionResult, + ExecutionSelectResult, + InsertLocalDBResult, + QueryTestResult, + SelectQueryHistoryResult, +) + + +class QueryRepository: + def execution( + self, + query: str, + driver_module: Any, + **kwargs: Any, + ) -> ExecutionSelectResult | ExecutionResult | BasicResult: + """ + 쿼리 수행합니다. + """ + connection = None + try: + connection = self._connect(driver_module, **kwargs) + cursor = connection.cursor() + + cursor.execute(query) + + if self._is_select_query(query): + rows = cursor.fetchall() + + if cursor.description: + columns = [desc[0] for desc in cursor.description] + data = [dict(zip(columns, row, strict=False)) for row in rows] + else: + columns = [] + data = [] + result = {"columns": columns, "data": data} + + return ExecutionSelectResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION, data=result) + + connection.commit() + return ExecutionResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION, data=cursor.rowcount) + except (AttributeError, driver_module.OperationalError, driver_module.DatabaseError): + return BasicResult(is_successful=False, code=CommonCode.FAIL_CONNECT_DB) + except Exception: + return BasicResult(is_successful=False, code=CommonCode.FAIL) + finally: + if connection: + connection.close() + + def execution_test( + self, + query: str, + driver_module: Any, + **kwargs: Any, + ) -> QueryTestResult: + """ + 쿼리가 문법적으로 유효한지 테스트합니다. + 실제 데이터는 변경되지 않습니다. (모든 작업은 롤백됩니다). + """ + connection = None + try: + connection = self._connect(driver_module, **kwargs) + cursor = connection.cursor() + cursor.execute(query) + + connection.rollback() + return QueryTestResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION_TEST, data=True) + except (AttributeError, driver_module.OperationalError, driver_module.DatabaseError): + return QueryTestResult(is_successful=False, code=CommonCode.FAIL_CONNECT_DB, data=False) + except Exception: + return QueryTestResult(is_successful=False, code=CommonCode.FAIL, data=False) + finally: + if connection: + connection.close() + + def create_query_history( + self, + sql: str, + data: tuple, + query: str, + ) -> InsertLocalDBResult: + """ + 쿼리 실행 결과를 저장합니다. + """ + db_path = get_db_path() + connection = None + try: + connection = sqlite3.connect(db_path) + cursor = connection.cursor() + cursor.execute(sql, data) + connection.commit() + + return ExecutionResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION, data=query) + except sqlite3.Error as e: + raise APIException(CommonCode.FAIL_CONNECT_DB) from e + except Exception as e: + raise APIException(CommonCode.FAIL_CREATE_QUERY) from e + finally: + if connection: + connection.close() + + def find_query_history(self, chat_tab_id: int) -> SelectQueryHistoryResult: + """ + 전달받은 쿼리를 실행하여 모든 DB 연결 정보를 조회합니다. + """ + db_path = get_db_path() + connection = None + try: + connection = sqlite3.connect(db_path) + connection.row_factory = sqlite3.Row + cursor = connection.cursor() + + sql = """ + SELECT qh.* + FROM query_history AS qh + LEFT JOIN chat_message AS cm ON qh.chat_message_id = cm.id + WHERE cm.chat_tab_id = ? + ORDER BY qh.created_at DESC + LIMIT 5; + """ + data = (chat_tab_id,) + + cursor.execute(sql, data) + rows = cursor.fetchall() + + columns = [desc[0] for desc in cursor.description] + data = [dict(zip(columns, row, strict=False)) for row in rows] + result = {"columns": columns, "data": data} + + return SelectQueryHistoryResult(is_successful=True, code=CommonCode.SUCCESS_FIND_QUERY_HISTORY, data=result) + except sqlite3.Error: + return SelectQueryHistoryResult(is_successful=False, code=CommonCode.FAIL_CONNECT_DB) + except Exception: + return SelectQueryHistoryResult(is_successful=False, code=CommonCode.FAIL) + finally: + if connection: + connection.close() + + # ───────────────────────────── + # DB 연결 메서드 + # ───────────────────────────── + def _connect(self, driver_module: Any, **kwargs): + if driver_module is oracledb: + if kwargs.get("user", "").lower() == "sys": + kwargs["mode"] = oracledb.AUTH_MODE_SYSDBA + return driver_module.connect(**kwargs) + elif "connection_string" in kwargs: + return driver_module.connect(kwargs["connection_string"]) + elif "db_name" in kwargs: + return driver_module.connect(kwargs["db_name"]) + else: + return driver_module.connect(**kwargs) + + def _is_select_query(self, query_text: str) -> bool: + for stmt in query_text.split(";"): + cleaned_stmt = stmt.strip().lower() + if cleaned_stmt and not cleaned_stmt.startswith("--") and cleaned_stmt.startswith("select"): + return True + return False + + +query_repository = QueryRepository() diff --git a/app/schemas/query/query_model.py b/app/schemas/query/query_model.py new file mode 100644 index 0000000..bec85a2 --- /dev/null +++ b/app/schemas/query/query_model.py @@ -0,0 +1,70 @@ +# app/schemas/query/query_model.py + +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from app.core.enum.db_key_prefix_name import DBSaveIdEnum +from app.core.exceptions import APIException +from app.core.status import CommonCode +from app.core.utils import generate_prefixed_uuid + + +def _is_empty(value: Any | None) -> bool: + """값이 None, 빈 문자열, 공백 문자열인지 검사""" + if value is None: + return True + if isinstance(value, str) and not value.strip(): + return True + return False + + +class QueryInfo(BaseModel): + user_db_id: str = Field(..., description="DB Key") + database: str | None = Field(None, description="database 명") + query_text: str | None = Field(None, description="쿼리 내용") + + @model_validator(mode="after") + def validate_required_fields(self) -> "QueryInfo": + """QueryInfo 모델에 대한 필수 필드 유효성 검사""" + if _is_empty(self.user_db_id): + raise APIException(CommonCode.NO_DB_DRIVER) + + if _is_empty(self.query_text): + raise APIException(CommonCode.NO_QUERY) + + return self + + +class RequestExecutionQuery(QueryInfo): + chat_message_id: str | None = Field(None, description="연결된 메시지 Key") + + @model_validator(mode="after") + def validate_chat_message_id(self) -> "RequestExecutionQuery": + """RequestExecutionQuery 모델에만 필요한 추가 필드 유효성 검사""" + if _is_empty(self.chat_message_id): + raise APIException(CommonCode.NO_CHAT_KEY) + + return self + + +class ExecutionQuery(RequestExecutionQuery): + id: str | None = Field(None, description="Query Key 값") + type: str | None = Field(None, description="디비 타입") + is_success: str | None = Field(None, description="성공 여부") + error_message: str | None = Field(None, description="에러 메시지") + + @classmethod + def from_query_info( + cls, query_info: RequestExecutionQuery, type: str, is_success: bool, error_message: str | None = None + ): + return cls( + id=generate_prefixed_uuid(DBSaveIdEnum.query.value), + user_db_id=query_info.user_db_id, + chat_message_id=query_info.chat_message_id, + database=query_info.database, + query_text=query_info.query_text, + type=type, + is_success="Y" if is_success else "N", + error_message=error_message, + ) diff --git a/app/schemas/query/result_model.py b/app/schemas/query/result_model.py new file mode 100644 index 0000000..4559553 --- /dev/null +++ b/app/schemas/query/result_model.py @@ -0,0 +1,41 @@ +# app/schemas/user_db/result_model.py + +from pydantic import BaseModel, Field + +from app.core.status import CommonCode + + +# 기본 반환 모델 +class BasicResult(BaseModel): + is_successful: bool = Field(..., description="성공 여부") + code: CommonCode = Field(None, description="결과 코드") + + +class ExecutionSelectResult(BasicResult): + """DB 조회 결과를 위한 확장 모델""" + + data: dict = Field(..., description="쿼리 조회 후 결과 - 데이터") + + +class ExecutionResult(BasicResult): + """DB 결과를 위한 확장 모델""" + + data: str = Field(..., description="쿼리 수행 후 결과") + + +class InsertLocalDBResult(BasicResult): + """DB 결과를 위한 확장 모델""" + + data: str = Field(..., description="쿼리 수행 후 결과") + + +class SelectQueryHistoryResult(BasicResult): + """DB 결과를 위한 확장 모델""" + + data: dict = Field(..., description="쿼리 이력 조회") + + +class QueryTestResult(BasicResult): + """DB Test 결과를 위한 확장 모델""" + + data: bool = Field(..., description="쿼리 수행 결과") diff --git a/app/services/query_service.py b/app/services/query_service.py new file mode 100644 index 0000000..4a4a651 --- /dev/null +++ b/app/services/query_service.py @@ -0,0 +1,116 @@ +# app/service/query_service.py + +import importlib +import sqlite3 +from typing import Any + +from fastapi import Depends + +from app.core.enum.db_driver import DBTypesEnum +from app.core.exceptions import APIException +from app.core.status import CommonCode +from app.repository.query_repository import QueryRepository, query_repository +from app.schemas.query.query_model import ExecutionQuery, QueryInfo, RequestExecutionQuery +from app.schemas.query.result_model import ( + BasicResult, + ExecutionResult, + ExecutionSelectResult, + QueryTestResult, + SelectQueryHistoryResult, +) +from app.schemas.user_db.db_profile_model import AllDBProfileInfo, DBProfileInfo + +query_repository_dependency = Depends(lambda: query_repository) + + +class QueryService: + def execution( + self, + query_info: RequestExecutionQuery, + db_info: AllDBProfileInfo, + repository: QueryRepository = query_repository, + ) -> ExecutionSelectResult | ExecutionResult | BasicResult: + """ + 쿼리 수행 후 결과를 저장합니다. + """ + driver_module = self._get_driver_module(db_info.type) + connect_kwargs = self._prepare_connection_args(db_info, query_info.database) + result = repository.execution(query_info.query_text, driver_module, **connect_kwargs) + try: + query_history_info = ExecutionQuery.from_query_info(query_info, db_info.type, result.is_successful, None) + sql, data = self._get_create_query_and_data(query_history_info) + repository.create_query_history(sql, data, query_history_info.query_text) + except Exception as e: + raise APIException(CommonCode.FAIL) from e + return result + + def execution_test( + self, query_info: QueryInfo, db_info: AllDBProfileInfo, repository: QueryRepository = query_repository + ) -> QueryTestResult: + """ + 쿼리 수행 후 결과를 저장합니다. + """ + driver_module = self._get_driver_module(db_info.type) + connect_kwargs = self._prepare_connection_args(db_info, query_info.database) + return repository.execution_test(query_info.query_text, driver_module, **connect_kwargs) + + def find_query_history( + self, chat_tab_id: int, repository: QueryRepository = query_repository + ) -> SelectQueryHistoryResult: + """ + 쿼리 기록을 조회합니다. + """ + try: + return repository.find_query_history(chat_tab_id) + except Exception as e: + raise APIException(CommonCode.FAIL) from e + + def _get_driver_module(self, db_type: str): + """ + DB 타입에 따라 동적으로 드라이버 모듈을 로드합니다. + """ + driver_name = DBTypesEnum[db_type.lower()].value + if driver_name == "sqlite3": + return sqlite3 + return importlib.import_module(driver_name) + + def _prepare_connection_args(self, db_info: DBProfileInfo, database_name: str) -> dict[str, Any]: + """ + DB 타입에 따라 연결에 필요한 매개변수를 딕셔너리로 구성합니다. + """ + # SQLite는 별도 처리 + if db_info.type == "sqlite": + return {"db_name": db_info.name} + + # 그 외 DB들은 공통 파라미터로 시작 + kwargs = {"host": db_info.host, "port": db_info.port, "user": db_info.username, "password": db_info.password} + + # DB 이름이 없을 경우, 기본 파라미터만 반환 + if not db_info.name and not database_name: + return kwargs + + # DB 이름이 있다면, 타입에 따라 적절한 파라미터를 추가합니다. + final_db = database_name if database_name else db_info.name + if db_info.type == "postgresql": + kwargs["dbname"] = final_db + elif db_info.type in ["mysql", "mariadb"]: + kwargs["database"] = final_db + elif db_info.type == "oracle": + kwargs["dsn"] = f"{db_info.host}:{db_info.port}/{final_db}" + + return kwargs + + # ───────────────────────────── + # 프로필 CRUD 쿼리 생성 메서드 + # ───────────────────────────── + def _get_create_query_and_data(self, query_info: ExecutionQuery) -> tuple[str, tuple]: + profile_dict = query_info.model_dump() + columns_to_insert = {k: v for k, v in profile_dict.items() if v is not None} + columns = ", ".join(columns_to_insert.keys()) + placeholders = ", ".join(["?"] * len(columns_to_insert)) + sql = f"INSERT INTO query_history ({columns}) VALUES ({placeholders})" + data = tuple(columns_to_insert.values()) + return sql, data + + +query_service = QueryService()