diff --git a/app/api/query_api.py b/app/api/query_api.py index 5b9927b..9504a24 100644 --- a/app/api/query_api.py +++ b/app/api/query_api.py @@ -1,6 +1,8 @@ # app/api/query_api.py +from typing import Any + from fastapi import APIRouter, Depends from app.core.exceptions import APIException @@ -36,20 +38,18 @@ def execution( @router.post( "/execute/test", - response_model=ResponseMessage[bool], + response_model=ResponseMessage[Any], summary="쿼리 실행", ) def execution_test( query_info: QueryInfo, service: QueryService = query_service_dependency, userDbservice: UserDbService = user_db_service_dependency, -) -> ResponseMessage[bool]: +) -> ResponseMessage[Any]: db_info = userDbservice.find_profile(query_info.user_db_id) result = service.execution_test(query_info, db_info) - if not result.is_successful: - raise APIException(result.code) return ResponseMessage.success(value=result.data, code=result.code) diff --git a/app/repository/query_repository.py b/app/repository/query_repository.py index 9dda359..e309e8b 100644 --- a/app/repository/query_repository.py +++ b/app/repository/query_repository.py @@ -72,14 +72,26 @@ def execution_test( cursor = connection.cursor() cursor.execute(query) - connection.rollback() - return QueryTestResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION_TEST, data=True) - except (AttributeError, driver_module.OperationalError, driver_module.DatabaseError): - return QueryTestResult(is_successful=False, code=CommonCode.FAIL_CONNECT_DB, data=False) - except Exception: - return QueryTestResult(is_successful=False, code=CommonCode.FAIL, data=False) + if not self._is_select_query(query): + return QueryTestResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION_TEST, data=True) + + rows = cursor.fetchall() + if cursor.description: + columns = [desc[0] for desc in cursor.description] + data = [dict(zip(columns, row, strict=False)) for row in rows] + else: + columns = [] + data = [] + + result = {"columns": columns, "data": data} + return QueryTestResult(is_successful=True, code=CommonCode.SUCCESS_EXECUTION, data=result) + except (AttributeError, driver_module.OperationalError, driver_module.DatabaseError) as e: + return QueryTestResult(is_successful=False, code=CommonCode.FAIL_CONNECT_DB, data=str(e)) + except Exception as e: + return QueryTestResult(is_successful=False, code=CommonCode.FAIL, data=str(e)) finally: if connection: + connection.rollback() connection.close() def create_query_history( diff --git a/app/schemas/query/result_model.py b/app/schemas/query/result_model.py index 4559553..e4d0c62 100644 --- a/app/schemas/query/result_model.py +++ b/app/schemas/query/result_model.py @@ -1,5 +1,7 @@ # app/schemas/user_db/result_model.py +from typing import Any + from pydantic import BaseModel, Field from app.core.status import CommonCode @@ -38,4 +40,4 @@ class SelectQueryHistoryResult(BasicResult): class QueryTestResult(BasicResult): """DB Test 결과를 위한 확장 모델""" - data: bool = Field(..., description="쿼리 수행 결과") + data: Any = Field(..., description="쿼리 수행 결과")