diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..94f480de94 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf \ No newline at end of file diff --git a/client b/client index 273fa3c4cc..5b4c9fbf18 160000 --- a/client +++ b/client @@ -1 +1 @@ -Subproject commit 273fa3c4cc7d87f942ade231ce9220b98fa595e2 +Subproject commit 5b4c9fbf18cde5f7c2b3455eef434d96c7ddf771 diff --git a/pom.xml b/pom.xml index 42262c105b..54505aea16 100644 --- a/pom.xml +++ b/pom.xml @@ -196,6 +196,17 @@ ${springboot.version} + + + com.openai + openai-java + 2.16.0 + + + org.jetbrains.kotlin + kotlin-stdlib-jdk7 + 1.9.24 + org.codehaus.groovy diff --git a/server/integration-test/src/test/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManagerTest.java b/server/integration-test/src/test/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManagerTest.java index 2d8b7e1980..d1123b82c7 100644 --- a/server/integration-test/src/test/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManagerTest.java +++ b/server/integration-test/src/test/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManagerTest.java @@ -47,6 +47,7 @@ import com.oceanbase.odc.service.collaboration.project.model.Project; import com.oceanbase.odc.service.connection.database.model.Database; import com.oceanbase.odc.service.connection.model.ConnectionConfig; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo.ScanningTaskStatus; import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; @@ -107,7 +108,8 @@ public static void tearDown() { public void test_start_groovyRule_OBMySQL() { List databases = createDatabases(ConnectType.OB_MYSQL); List rules = Arrays.asList(createGroovySensitiveRules()); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, mysqlConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, mysqlConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(2, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); @@ -117,7 +119,8 @@ public void test_start_groovyRule_OBMySQL() { public void test_start_groovyRule_OBOracle() { List databases = createDatabases(ConnectType.OB_ORACLE); List rules = Arrays.asList(createGroovySensitiveRules()); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, oracleConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, oracleConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(2, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); @@ -127,7 +130,8 @@ public void test_start_groovyRule_OBOracle() { public void test_start_pathRule_OBMySQL() { List databases = createDatabases(ConnectType.OB_MYSQL); List rules = Arrays.asList(createPathSensitiveRules()); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, mysqlConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, mysqlConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(20, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); @@ -137,7 +141,8 @@ public void test_start_pathRule_OBMySQL() { public void test_start_pathRule_OBMOracle() { List databases = createDatabases(ConnectType.OB_ORACLE); List rules = Arrays.asList(createPathSensitiveRules()); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, oracleConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, oracleConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(20, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); @@ -147,7 +152,8 @@ public void test_start_pathRule_OBMOracle() { public void test_start_RegexRule_OBMySQL() { List databases = createDatabases(ConnectType.OB_MYSQL); List rules = Arrays.asList(createRegexSensitiveRules(ConnectType.OB_MYSQL)); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, mysqlConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, mysqlConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(6, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); @@ -157,7 +163,8 @@ public void test_start_RegexRule_OBMySQL() { public void test_start_RegexRule_OBOracle() { List databases = createDatabases(ConnectType.OB_ORACLE); List rules = Arrays.asList(createRegexSensitiveRules(ConnectType.OB_ORACLE)); - SensitiveColumnScanningTaskInfo taskInfo = manager.start(databases, rules, oracleConnectionConfig, null); + SensitiveColumnScanningTaskInfo taskInfo = + manager.start(databases, rules, ScanningModeType.RULES_ONLY, oracleConnectionConfig, null); await().atMost(20, SECONDS) .until(() -> manager.get(taskInfo.getTaskId()).getStatus() == ScanningTaskStatus.SUCCESS); Assert.assertEquals(6, manager.get(taskInfo.getTaskId()).getSensitiveColumns().size()); diff --git a/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/ErrorCodes.java b/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/ErrorCodes.java index d7954a985e..64e73603c4 100644 --- a/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/ErrorCodes.java +++ b/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/ErrorCodes.java @@ -334,6 +334,16 @@ public enum ErrorCodes implements ErrorCode { ExtractFileFailed, InvalidSignature, + /** + * AI Service + */ + AIServiceNotAvailable, + AIConfigurationIncomplete, + AIClientNotInitialized, + AIInferenceServiceError, + AIResponseFormatError, + AIResponseCountMismatch, + ; diff --git a/server/odc-core/src/main/resources/i18n/ErrorMessages.properties b/server/odc-core/src/main/resources/i18n/ErrorMessages.properties index 89ccb301f4..808054c177 100644 --- a/server/odc-core/src/main/resources/i18n/ErrorMessages.properties +++ b/server/odc-core/src/main/resources/i18n/ErrorMessages.properties @@ -212,4 +212,10 @@ com.oceanbase.odc.ErrorCodes.UpdateNotAllowed=Editing is not allowed in the curr com.oceanbase.odc.ErrorCodes.PauseNotAllowed=Disabling is not allowed in the current state, please check if there are any records in execution. com.oceanbase.odc.ErrorCodes.DeleteNotAllowed=Deletion is not allowed in the current state. com.oceanbase.odc.ErrorCodes.ExtractFileFailed=Failed to extract the file. Please check whether the file is correct. Details: {0} -com.oceanbase.odc.ErrorCodes.InvalidSignature=File verification failed. Do not modify exported files or check if the correct key was used. \ No newline at end of file +com.oceanbase.odc.ErrorCodes.InvalidSignature=File verification failed. Do not modify exported files or check if the correct key was used. +com.oceanbase.odc.ErrorCodes.AIServiceNotAvailable=AI service is not available. Please contact administrator to enable AI service. +com.oceanbase.odc.ErrorCodes.AIConfigurationIncomplete=AI configuration is incomplete. Please contact administrator to configure AI parameters. +com.oceanbase.odc.ErrorCodes.AIClientNotInitialized=AI client is not initialized. Please check AI configuration and restart service. +com.oceanbase.odc.ErrorCodes.AIInferenceServiceError=Failed to call AI inference service. Details: {0} +com.oceanbase.odc.ErrorCodes.AIResponseFormatError=AI response format is invalid. Details: {0} +com.oceanbase.odc.ErrorCodes.AIResponseCountMismatch=AI response count does not match expected count. Expected: {0}, Actual: {1} \ No newline at end of file diff --git a/server/odc-core/src/main/resources/i18n/ErrorMessages_zh_CN.properties b/server/odc-core/src/main/resources/i18n/ErrorMessages_zh_CN.properties index 1a84987660..7bb054aa71 100644 --- a/server/odc-core/src/main/resources/i18n/ErrorMessages_zh_CN.properties +++ b/server/odc-core/src/main/resources/i18n/ErrorMessages_zh_CN.properties @@ -214,4 +214,11 @@ com.oceanbase.odc.ErrorCodes.UnsupportedSyncTableStructure=结构同步暂不支 com.oceanbase.odc.ErrorCodes.ScheduleIntervalTooShort=执行间隔配置过短,请重新配置。最小间隔为:{0} 秒 com.oceanbase.odc.ErrorCodes.ExtractFileFailed=提取文件失败,请确认文件是否正确,错误详情 {0} -com.oceanbase.odc.ErrorCodes.InvalidSignature=文件验签不通过,请勿修改导出文件或检查密钥是否正确 \ No newline at end of file +com.oceanbase.odc.ErrorCodes.InvalidSignature=文件验签不通过,请勿修改导出文件或检查密钥是否正确 + +com.oceanbase.odc.ErrorCodes.AIServiceNotAvailable=AI 服务不可用,请联系管理员启用 AI 服务 +com.oceanbase.odc.ErrorCodes.AIConfigurationIncomplete=AI 配置不完整,请联系管理员配置 AI 参数 +com.oceanbase.odc.ErrorCodes.AIClientNotInitialized=AI 客户端未初始化,请检查 AI 配置并重启服务 +com.oceanbase.odc.ErrorCodes.AIInferenceServiceError=调用 AI 推理服务失败,错误详情:{0} +com.oceanbase.odc.ErrorCodes.AIResponseFormatError=AI 响应格式无效,错误详情:{0} +com.oceanbase.odc.ErrorCodes.AIResponseCountMismatch=AI 响应数量与预期不符,预期:{0},实际:{1} \ No newline at end of file diff --git a/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_20__alter_sensitive_rule.sql b/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_20__alter_sensitive_rule.sql new file mode 100644 index 0000000000..163ce9a466 --- /dev/null +++ b/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_20__alter_sensitive_rule.sql @@ -0,0 +1,7 @@ +-- Add AI-related columns to table `data_security_sensitive_rule` +alter table `data_security_sensitive_rule` + add column `ai_sensitive_types` text default null comment 'A list of sensitive data types for AI rules, stored as a JSON array string.'; + +alter table `data_security_sensitive_rule` + add column `ai_custom_prompt` text default null comment 'User-defined custom prompt for AI rules.'; + diff --git a/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_21__add_ai_system_config.sql b/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_21__add_ai_system_config.sql new file mode 100644 index 0000000000..9219e0d17d --- /dev/null +++ b/server/odc-migrate/src/main/resources/migrate/common/V_4_3_4_21__add_ai_system_config.sql @@ -0,0 +1,16 @@ + +INSERT INTO config_system_configuration(`key`, `value`, `description`) +VALUES('odc.ai.enabled', 'false', 'Whether AI feature is enabled, disabled by default') +ON DUPLICATE KEY UPDATE `id`=`id`; + +INSERT INTO config_system_configuration(`key`, `value`, `description`) +VALUES('odc.ai.api-key', '', 'AI API key, required when AI feature is enabled') +ON DUPLICATE KEY UPDATE `id`=`id`; + +INSERT INTO config_system_configuration(`key`, `value`, `description`) +VALUES('odc.ai.base-url', 'https://api.openai.com', 'AI API base URL, defaults to OpenAI official API endpoint') +ON DUPLICATE KEY UPDATE `id`=`id`; + +INSERT INTO config_system_configuration(`key`, `value`, `description`) +VALUES('odc.ai.model', 'gpt-3.5-turbo', 'AI model to use, defaults to gpt-3.5-turbo') +ON DUPLICATE KEY UPDATE `id`=`id`; \ No newline at end of file diff --git a/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/AIController.java b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/AIController.java new file mode 100644 index 0000000000..57a2dd3639 --- /dev/null +++ b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/AIController.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.server.web.controller.v2; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import com.oceanbase.odc.core.authority.util.SkipAuthorize; +import com.oceanbase.odc.service.common.response.Responses; +import com.oceanbase.odc.service.common.response.SuccessResponse; +import com.oceanbase.odc.service.datasecurity.ai.AIConfig; +import com.oceanbase.odc.service.datasecurity.ai.AIInferenceService; +import com.oceanbase.odc.service.datasecurity.ai.AIStatusResponse; + +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Api(tags = "AI") +@RestController +@RequestMapping("/api/v2/ai") +public class AIController { + + @Autowired + private AIConfig aiConfig; + + @Autowired + private AIInferenceService aiInferenceService; + + + @ApiOperation(value = "Query the status of the AI function", + notes = "Return the status of whether the AI function is enabled and its configuration status") + @SkipAuthorize("AI status is safe to query for authenticated users") + @GetMapping("/status") + public SuccessResponse getAIStatus() { + AIStatusResponse response = new AIStatusResponse(); + response.setEnabled(aiConfig.isEnabled()); + response.setAvailable(aiInferenceService.isAIAvailable()); + response.setModel(aiConfig.getModel()); + response.setBaseUrl(aiConfig.getBaseUrl()); + response.setApiKeyConfigured(aiConfig.getApiKey() != null && !aiConfig.getApiKey().trim().isEmpty()); + + return Responses.success(response); + } +} diff --git a/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/SensitiveColumnController.java b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/SensitiveColumnController.java index 9fd98510b3..dd7370ba16 100644 --- a/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/SensitiveColumnController.java +++ b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/SensitiveColumnController.java @@ -35,12 +35,14 @@ import com.oceanbase.odc.service.common.response.Responses; import com.oceanbase.odc.service.common.response.SuccessResponse; import com.oceanbase.odc.service.datasecurity.SensitiveColumnService; +import com.oceanbase.odc.service.datasecurity.SingleTableScanTaskManager; import com.oceanbase.odc.service.datasecurity.model.DatabaseWithAllColumns; import com.oceanbase.odc.service.datasecurity.model.QuerySensitiveColumnParams; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumn; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningReq; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnStats; +import com.oceanbase.odc.service.datasecurity.model.SingleTableScanReq; import com.oceanbase.odc.service.datasecurity.model.UpdateSensitiveColumnsReq; import io.swagger.annotations.ApiOperation; @@ -143,4 +145,25 @@ public SuccessResponse getScanningResults(@Path return Responses.success(service.getScanningResults(projectId, taskId)); } + @ApiOperation(value = "stopScanning", notes = "Stop a sensitive column scanning task") + @RequestMapping(value = "/stopScanning", method = RequestMethod.POST) + public SuccessResponse stopScanning(@PathVariable Long projectId, + @RequestParam String taskId) { + return Responses.success(service.stopScanning(projectId, taskId)); + } + + @ApiOperation(value = "getSingleTableScanResult", notes = "Get single table scan result") + @RequestMapping(value = "/singleTableScan/{taskId}/result", method = RequestMethod.GET) + public SuccessResponse getSingleTableScanResult( + @PathVariable Long projectId, + @PathVariable String taskId) { + return Responses.success(service.getSingleTableScanResult(projectId, taskId)); + } + + @ApiOperation(value = "scanSingleTableAsync", notes = "Start an asynchronous single table scan") + @RequestMapping(value = "/scanSingleTableAsync", method = RequestMethod.POST) + public SuccessResponse scanSingleTableAsync(@PathVariable Long projectId, + @RequestBody SingleTableScanReq req) { + return Responses.success(service.scanSingleTableAsync(projectId, req)); + } } diff --git a/server/odc-server/src/test/java/com/oceanbase/odc/supervisor/SupervisorApplicationTest.java b/server/odc-server/src/test/java/com/oceanbase/odc/supervisor/SupervisorApplicationTest.java index 5fb40d6311..6946197185 100644 --- a/server/odc-server/src/test/java/com/oceanbase/odc/supervisor/SupervisorApplicationTest.java +++ b/server/odc-server/src/test/java/com/oceanbase/odc/supervisor/SupervisorApplicationTest.java @@ -15,14 +15,6 @@ */ package com.oceanbase.odc.supervisor; -/** - * @author longpeng.zlp - * @date 2024/12/9 15:59 - */ -/** - * @author longpeng.zlp - * @date 2024/12/9 15:59 - */ import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; diff --git a/server/odc-service/pom.xml b/server/odc-service/pom.xml index 5d28d6077c..e811f237ef 100644 --- a/server/odc-service/pom.xml +++ b/server/odc-service/pom.xml @@ -75,6 +75,14 @@ commons-beanutils commons-beanutils + + com.openai + openai-java + + + org.jetbrains.kotlin + kotlin-stdlib-jdk7 + org.apache.commons commons-compress diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/config/CommonSecurityProperties.java b/server/odc-service/src/main/java/com/oceanbase/odc/config/CommonSecurityProperties.java index b2e8bbfb4f..86a78a66db 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/config/CommonSecurityProperties.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/config/CommonSecurityProperties.java @@ -56,7 +56,8 @@ public class CommonSecurityProperties { "/api/v2/internal/file/downloadImportFile", "/api/v2/info", "/api/v2/sso/state", - "/api/v2/encryption/publicKey"}; + "/api/v2/encryption/publicKey", + "/api/v2/ai/status"}; private static final String[] STATIC_RESOURCES = new String[] { "/", diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/metadb/datasecurity/SensitiveRuleEntity.java b/server/odc-service/src/main/java/com/oceanbase/odc/metadb/datasecurity/SensitiveRuleEntity.java index 405452ab90..23cd038918 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/metadb/datasecurity/SensitiveRuleEntity.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/metadb/datasecurity/SensitiveRuleEntity.java @@ -113,4 +113,10 @@ public class SensitiveRuleEntity { @Column(name = "update_time", nullable = false, insertable = false, updatable = false) private Date updateTime; + @Convert(converter = JsonListConverter.class) + @Column(name = "ai_sensitive_types") + private List aiSensitiveTypes; + + @Column(name = "ai_custom_prompt") + private String aiCustomPrompt; } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/MaskingAlgorithmService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/MaskingAlgorithmService.java index 6a256fa0c8..744ed3e3e3 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/MaskingAlgorithmService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/MaskingAlgorithmService.java @@ -207,6 +207,21 @@ public Long getDefaultAlgorithmIdByOrganizationId(@NonNull Long organizationId) return entities.get(0).getId(); } + @SkipAuthorize("odc internal usages") + public Optional getAlgorithmIdByName(@NonNull String algorithmName, @NonNull Long organizationId) { + List entities = + algorithmRepository.findByNameAndOrganizationId(algorithmName, organizationId); + if (entities.isEmpty()) { + log.warn("No masking algorithm found with name: {} for organization: {}", algorithmName, organizationId); + return Optional.empty(); + } + if (entities.size() > 1) { + log.warn("Multiple masking algorithms found with name: {} for organization: {}, using the first one", + algorithmName, organizationId); + } + return Optional.of(entities.get(0).getId()); + } + @SkipAuthorize("odc internal usages") public List getMaskingAlgorithmsByOrganizationId(@NonNull Long organizationId) { return organizationId2Algorithms.get(organizationId); diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnRecognizer.java deleted file mode 100644 index 6e7c4a979b..0000000000 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnRecognizer.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2023 OceanBase. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.oceanbase.odc.service.datasecurity; - -import java.util.ArrayList; -import java.util.List; - -import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; -import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; -import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; -import com.oceanbase.tools.dbbrowser.model.DBTableColumn; - -/** - * @author gaoda.xy - * @date 2023/5/30 10:30 - */ -public class SensitiveColumnRecognizer implements ColumnRecognizer { - - private Long sensitiveRuleId; - private Long maskingAlgorithmId; - private SensitiveLevel sensitiveLevel; - private final List sensitiveRules; - private final List recognizers; - - public SensitiveColumnRecognizer(List rules) { - this.sensitiveRules = rules; - this.recognizers = new ArrayList<>(); - for (SensitiveRule rule : this.sensitiveRules) { - this.recognizers.add(ColumnRecognizerFactory.create(rule)); - } - } - - public Long sensitiveRuleId() { - return this.sensitiveRuleId; - } - - public Long maskingAlgorithmId() { - return this.maskingAlgorithmId; - } - - public SensitiveLevel sensitiveLevel() { - return this.sensitiveLevel; - } - - @Override - public boolean recognize(DBTableColumn column) { - for (int i = 0; i < recognizers.size(); i++) { - if (recognizers.get(i).recognize(column)) { - SensitiveRule rule = sensitiveRules.get(i); - this.sensitiveRuleId = rule.getId(); - this.maskingAlgorithmId = rule.getMaskingAlgorithmId(); - this.sensitiveLevel = rule.getLevel(); - return true; - } - } - return false; - } - -} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanner.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanner.java new file mode 100644 index 0000000000..236c3550f8 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanner.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.oceanbase.odc.service.datasecurity.factory.ColumnRecognizerFactory; +import com.oceanbase.odc.service.datasecurity.factory.ScanningStrategyFactory; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.odc.service.datasecurity.strategy.ScanningStrategy; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public class SensitiveColumnScanner { + + private final List basicRecognizers; + private final List aiRecognizers; + private final ScanningStrategyFactory strategyFactory; + + public SensitiveColumnScanner(List rules, ScanningStrategyFactory strategyFactory) { + this.basicRecognizers = rules.stream() + .filter(r -> r.getType() != SensitiveRuleType.AI) + .map(ColumnRecognizerFactory::create) + .collect(Collectors.toList()); + this.aiRecognizers = rules.stream() + .filter(r -> r.getType() == SensitiveRuleType.AI) + .map(ColumnRecognizerFactory::create) + .collect(Collectors.toList()); + this.strategyFactory = strategyFactory; + } + + public ScanResult scan(DBTableColumn column, ScanningModeType mode) { + ScanningStrategy strategy = strategyFactory.getStrategy(mode); + return strategy.scan(column, basicRecognizers, aiRecognizers); + } + + public Map scanBatch(List columns, ScanningModeType mode) { + ScanningStrategy strategy = strategyFactory.getStrategy(mode); + return strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTask.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTask.java index 00d3ccc390..2dd234a21c 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTask.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTask.java @@ -20,81 +20,201 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; import com.oceanbase.odc.core.shared.constant.ErrorCodes; +import com.oceanbase.odc.service.common.util.SpringContextUtil; import com.oceanbase.odc.service.connection.database.model.Database; +import com.oceanbase.odc.service.datasecurity.factory.ScanningStrategyFactory; +import com.oceanbase.odc.service.datasecurity.model.DefaultSensitiveType; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumn; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnMeta; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo.ScanningTaskStatus; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnType; import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; +import lombok.extern.slf4j.Slf4j; + /** * @author gaoda.xy * @date 2023/5/25 14:43 */ +@Slf4j public class SensitiveColumnScanningTask implements Callable { private final Database database; - private final SensitiveColumnRecognizer recognizer; + private final SensitiveColumnScanner scanner; + private final ScanningModeType scanningMode; private final SensitiveColumnScanningTaskInfo taskInfo; private final Map> table2Columns; private final Map> view2Columns; private final Set existsSensitiveColumns; + private final Map ruleMap; - public SensitiveColumnScanningTask(Database database, List rules, + public SensitiveColumnScanningTask(Database database, List rules, ScanningModeType scanningMode, SensitiveColumnScanningTaskInfo taskInfo, List existsSensitiveColumns, Map> table2Columns, Map> view2Columns) { this.database = database; - this.recognizer = new SensitiveColumnRecognizer(rules); + this.scanningMode = scanningMode; + ScanningStrategyFactory strategyFactory = new ScanningStrategyFactory(); + this.scanner = new SensitiveColumnScanner(rules, strategyFactory); this.table2Columns = table2Columns; this.view2Columns = view2Columns; this.taskInfo = taskInfo; this.existsSensitiveColumns = new HashSet<>(existsSensitiveColumns); + this.ruleMap = rules.stream().collect(Collectors.toMap(SensitiveRule::getId, Function.identity())); + } + + private String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); } @Override - public Void call() throws Exception { + public Void call() { try { taskInfo.setStatus(ScanningTaskStatus.RUNNING); scanColumns(table2Columns, SensitiveColumnType.TABLE_COLUMN); + if (taskInfo.isCancelled()) { + return null; + } scanColumns(view2Columns, SensitiveColumnType.VIEW_COLUMN); } catch (Exception e) { - taskInfo.setCompleteTime(new Date()); - taskInfo.setStatus(ScanningTaskStatus.FAILED); - taskInfo.setErrorCode(ErrorCodes.Unexpected); - taskInfo.setErrorMsg(String.format("Some errors happen when scanning sensitive column, database=%s", - database.getName())); + if (!taskInfo.isCancelled()) { + taskInfo.setStatus(ScanningTaskStatus.FAILED); + taskInfo.setErrorCode(ErrorCodes.Unexpected); + taskInfo.setErrorMsg(String.format("Error during sensitive column scanning on database=%s, reason=%s", + database.getName(), e.getMessage())); + taskInfo.setCompleteTime(new Date()); + } } return null; } private void scanColumns(Map> object2Columns, SensitiveColumnType columnType) { - for (String objectName : object2Columns.keySet()) { - List sensitiveColumns = new ArrayList<>(); - for (DBTableColumn dbTableColumn : object2Columns.get(objectName)) { - if (recognizer.recognize(dbTableColumn) && !existsSensitiveColumns - .contains(new SensitiveColumnMeta(database.getId(), objectName, dbTableColumn.getName()))) { - SensitiveColumn column = new SensitiveColumn(); - column.setType(columnType); - column.setDatabase(database); - column.setTableName(objectName); - column.setColumnName(dbTableColumn.getName()); - column.setMaskingAlgorithmId(recognizer.maskingAlgorithmId()); - column.setSensitiveRuleId(recognizer.sensitiveRuleId()); - column.setLevel(recognizer.sensitiveLevel()); - sensitiveColumns.add(column); - existsSensitiveColumns - .add(new SensitiveColumnMeta(database.getId(), objectName, dbTableColumn.getName())); + if (object2Columns.isEmpty()) { + return; + } + + List> tableFutures = object2Columns.entrySet().stream() + .map(entry -> CompletableFuture.runAsync(() -> { + String objectName = entry.getKey(); + List columns = entry.getValue(); + + try { + if (taskInfo.isCancelled()) { + return; + } + Map scanResults = this.scanner.scanBatch(columns, this.scanningMode); + if (taskInfo.isCancelled()) { + return; + } + + List sensitiveColumns = new ArrayList<>(); + for (DBTableColumn dbTableColumn : columns) { + String columnKey = getColumnKey(dbTableColumn); + ScanResult scanResult = scanResults.get(columnKey); + + if (scanResult != null) { + Optional finalResultOpt = scanResult + .getFinalResult(this.scanningMode); + finalResultOpt.ifPresent(finalResult -> { + SensitiveColumnMeta meta = new SensitiveColumnMeta(database.getId(), objectName, + dbTableColumn.getName()); + synchronized (existsSensitiveColumns) { + if (!existsSensitiveColumns.contains(meta)) { + SensitiveColumn column = createSensitiveColumn(columnType, objectName, + dbTableColumn, + finalResult); + sensitiveColumns.add(column); + existsSensitiveColumns.add(meta); + } + } + }); + } + } + if (!sensitiveColumns.isEmpty()) { + taskInfo.addSensitiveColumns(sensitiveColumns); + } + taskInfo.addFinishedTableCount(); + } catch (Exception e) { + log.error("Failed to scan table {}: {}", objectName, e.getMessage(), e); + taskInfo.addFinishedTableCount(); + } + })) + .collect(Collectors.toList()); + + CompletableFuture.allOf(tableFutures.toArray(new CompletableFuture[0])).join(); + } + + private SensitiveColumn createSensitiveColumn(SensitiveColumnType columnType, String objectName, + DBTableColumn dbTableColumn, RecognitionResult result) { + SensitiveColumn column = new SensitiveColumn(); + column.setType(columnType); + column.setDatabase(database); + column.setTableName(objectName); + column.setColumnName(dbTableColumn.getName()); + column.setSensitiveRuleId(result.getMatchedRuleId()); + column.setLevel(result.getLevel()); + Long maskingAlgorithmId = determineMaskingAlgorithmId(result); + column.setMaskingAlgorithmId(maskingAlgorithmId); + + return column; + } + + private Long determineMaskingAlgorithmId(RecognitionResult result) { + SensitiveRule matchedRule = this.ruleMap.get(result.getMatchedRuleId()); + if (matchedRule == null) { + return getSystemDefaultAlgorithmId(); + } + + if (SensitiveRuleType.AI.equals(result.getSourceRuleType()) && result.getSensitiveType() != null) { + return handleAiRecognitionResult(result.getSensitiveType()); + } + + return matchedRule.getMaskingAlgorithmId(); + } + + private Long handleAiRecognitionResult(String sensitiveType) { + if (DefaultSensitiveType.isDefaultType(sensitiveType)) { + Optional algorithmNameOpt = DefaultSensitiveType.getAlgorithmNameBySensitiveType(sensitiveType); + if (algorithmNameOpt.isPresent()) { + try { + MaskingAlgorithmService algorithmService = SpringContextUtil.getBean(MaskingAlgorithmService.class); + Optional algorithmIdOpt = algorithmService.getAlgorithmIdByName(algorithmNameOpt.get(), + database.getOrganizationId()); + if (algorithmIdOpt.isPresent()) { + return algorithmIdOpt.get(); + } + } catch (Exception e) { + log.error("Failed to get algorithm ID by name: {}", e.getMessage(), e); } } - taskInfo.addSensitiveColumns(sensitiveColumns); - taskInfo.addFinishedTableCount(); } + + return getSystemDefaultAlgorithmId(); } + private Long getSystemDefaultAlgorithmId() { + try { + MaskingAlgorithmService algorithmService = SpringContextUtil.getBean(MaskingAlgorithmService.class); + return algorithmService.getDefaultAlgorithmIdByOrganizationId(database.getOrganizationId()); + } catch (Exception e) { + log.error("Failed to get default masking algorithm ID: {}", e.getMessage(), e); + return null; + } + } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManager.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManager.java index 72a93c1676..09102f78ca 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManager.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnScanningTaskManager.java @@ -37,6 +37,7 @@ import com.oceanbase.odc.core.shared.constant.ResourceType; import com.oceanbase.odc.service.connection.database.model.Database; import com.oceanbase.odc.service.connection.model.ConnectionConfig; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnMeta; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo.ScanningTaskStatus; @@ -65,6 +66,7 @@ public class SensitiveColumnScanningTaskManager { private StatefulUuidStateIdGenerator statefulUuidStateIdGenerator; public SensitiveColumnScanningTaskInfo start(List databases, List rules, + ScanningModeType scanningMode, ConnectionConfig connectionConfig, Map> databaseId2SensitiveColumns) { ConnectionSession session = new DefaultConnectSessionFactory(connectionConfig).generateSession(); try { @@ -101,8 +103,8 @@ public SensitiveColumnScanningTaskInfo start(List databases, List()), + SensitiveColumnScanningTask subTask = new SensitiveColumnScanningTask(database, rules, scanningMode, + taskInfo, sensitiveColumns, database2Table2ColumnsList.getOrDefault(database, new HashMap<>()), database2View2ColumnsList.getOrDefault(database, new HashMap<>())); try { executor.submit(subTask); @@ -132,4 +134,17 @@ public SensitiveColumnScanningTaskInfo get(String taskId) { return taskInfo; } + public boolean stop(String taskId) { + SensitiveColumnScanningTaskInfo taskInfo = cache.get(taskId); + if (taskInfo != null && taskInfo.getStatus() == ScanningTaskStatus.RUNNING) { + taskInfo.setCancelled(true); + taskInfo.setStatus(ScanningTaskStatus.CANCELLED); + taskInfo.setCompleteTime(new Date()); + log.info("Sensitive column scanning task stopped, taskId: {}", taskId); + return true; + } + return false; + } + } + diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnService.java index 6569b068d9..950fd57888 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveColumnService.java @@ -23,7 +23,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; +import java.util.UUID; import java.util.stream.Collectors; import javax.validation.Valid; @@ -61,15 +63,22 @@ import com.oceanbase.odc.service.connection.database.model.Database; import com.oceanbase.odc.service.connection.model.ConnectionConfig; import com.oceanbase.odc.service.datasecurity.extractor.model.DBColumn; +import com.oceanbase.odc.service.datasecurity.factory.ScanningStrategyFactory; import com.oceanbase.odc.service.datasecurity.model.DatabaseWithAllColumns; import com.oceanbase.odc.service.datasecurity.model.MaskingAlgorithm; import com.oceanbase.odc.service.datasecurity.model.QuerySensitiveColumnParams; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumn; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnMeta; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningReq; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnScanningTaskInfo; import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnStats; +import com.oceanbase.odc.service.datasecurity.model.SensitiveColumnType; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.model.SingleTableScanReq; import com.oceanbase.odc.service.datasecurity.util.SensitiveColumnMapper; import com.oceanbase.odc.service.db.browser.DBSchemaAccessors; import com.oceanbase.odc.service.feature.VersionDiffConfigService; @@ -115,6 +124,8 @@ public class SensitiveColumnService { private HorizontalDataPermissionValidator permissionValidator; @Autowired private VersionDiffConfigService versionDiffConfigService; + @Autowired + private SingleTableScanTaskManager singleTableScanTaskManager; @Transactional(rollbackFor = Exception.class) @PreAuthenticate(hasAnyResourceRole = {"OWNER, DBA, SECURITY_ADMINISTRATOR"}, @@ -399,7 +410,8 @@ public SensitiveColumnScanningTaskInfo startScanning(@NotNull Long projectId, PreConditions.notEmpty(rules, "sensitiveRules"); ConnectionConfig connectionConfig = databaseService.findDataSourceForConnectById(databases.get(0).getId()); Map> databaseId2SensitiveColumns = listExistSensitiveColumns(databaseIds); - return scanningTaskManager.start(databases, rules, connectionConfig, databaseId2SensitiveColumns); + return scanningTaskManager.start(databases, rules, req.getScanningMode(), connectionConfig, + databaseId2SensitiveColumns); } @Transactional(rollbackFor = Exception.class) @@ -416,6 +428,21 @@ public SensitiveColumnScanningTaskInfo getScanningResults(@NotNull Long projectI return taskInfo; } + @Transactional(rollbackFor = Exception.class) + @PreAuthenticate(hasAnyResourceRole = {"OWNER, DBA, SECURITY_ADMINISTRATOR"}, + actions = {"OWNER", "DBA", "SECURITY_ADMINISTRATOR"}, resourceType = "ODC_PROJECT", + indexOfIdParam = 0) + @StatefulRoute(stateName = StateName.UUID_STATEFUL_ID, stateIdExpression = "#taskId") + public Boolean stopScanning(@NotNull Long projectId, @NotBlank String taskId) { + SensitiveColumnScanningTaskInfo taskInfo = scanningTaskManager.get(taskId); + if (!Objects.equals(taskInfo.getProjectId(), projectId)) { + String errorMsg = String.format("Sensitive column scanning task not exists, taskId=%s", taskId); + throw new NotFoundException(ErrorCodes.IllegalArgument, new Object[] {"taskId", errorMsg}, null); + } + return scanningTaskManager.stop(taskId); + } + + @SkipAuthorize("odc internal usages") public SensitiveColumnEntity nullSafeGet(@NotNull Long id) { return repository.findById(id) @@ -527,4 +554,126 @@ private Map> getFilteringExistColumns(Long databaseI return filtered; } + @Transactional(rollbackFor = Exception.class) + @PreAuthenticate(hasAnyResourceRole = {"OWNER, DBA, SECURITY_ADMINISTRATOR"}, actions = {"OWNER", "DBA", + "SECURITY_ADMINISTRATOR"}, resourceType = "ODC_PROJECT", indexOfIdParam = 0) + public String scanSingleTableAsync(@NotNull Long projectId, @NotNull @Valid SingleTableScanReq req) { + final Long currentUserId = authenticationFacade.currentUserId(); + final Long currentOrganizationId = authenticationFacade.currentOrganizationId(); + final String currentUserAccountName = authenticationFacade.currentUserAccountName(); + String taskId = UUID.randomUUID().toString(); + singleTableScanTaskManager.startTask(taskId, () -> { + try { + com.oceanbase.odc.service.iam.util.SecurityContextUtils.setCurrentUser( + currentUserId, currentOrganizationId, currentUserAccountName); + + List result = performSingleTableScan(projectId, req); + singleTableScanTaskManager.setTaskResult(taskId, result); + } catch (Exception e) { + log.error("Single table scan failed for taskId: {}, projectId: {}, databaseId: {}, tableName: {}", + taskId, projectId, req.getDatabaseId(), req.getTableName(), e); + String errorMessage = e.getMessage() != null ? e.getMessage() + : "An unknown error occurred during the scanning process: " + e.getClass().getSimpleName(); + singleTableScanTaskManager.setTaskError(taskId, errorMessage); + } + }); + return taskId; + } + + @PreAuthenticate(hasAnyResourceRole = {"OWNER, DBA, SECURITY_ADMINISTRATOR"}, actions = {"OWNER", "DBA", + "SECURITY_ADMINISTRATOR"}, resourceType = "ODC_PROJECT", indexOfIdParam = 0) + public SingleTableScanTaskManager.SingleTableScanTask getSingleTableScanResult(@NotNull Long projectId, + @NotBlank String taskId) { + return singleTableScanTaskManager.getTask(taskId); + } + + private List performSingleTableScan(@NotNull Long projectId, + @NotNull @Valid SingleTableScanReq req) { + Database database = databaseService.detail(req.getDatabaseId()); + PreConditions.notNull(database, "database"); + checkProjectDatabases(projectId, Collections.singletonList(req.getDatabaseId())); + ConnectionConfig connectionConfig = connectionService + .getForConnectionSkipPermissionCheck(database.getDataSource().getId()); + List tableColumns = getTableColumns(connectionConfig, database.getName(), req.getTableName()); + if (CollectionUtils.isEmpty(tableColumns)) { + return Collections.emptyList(); + } + List rules = getScanningRules(projectId, null); + if (CollectionUtils.isEmpty(rules)) { + return Collections.emptyList(); + } + ScanningStrategyFactory strategyFactory = new ScanningStrategyFactory(); + SensitiveColumnScanner scanner = new SensitiveColumnScanner(rules, strategyFactory); + Map scanResults = scanner.scanBatch(tableColumns, req.getScanningMode()); + + List results = new ArrayList<>(); + for (DBTableColumn column : tableColumns) { + String columnKey = String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + + ScanResult scanResult = scanResults.get(columnKey); + if (scanResult != null) { + Optional finalResult = scanResult.getFinalResult(req.getScanningMode()); + + if (finalResult.isPresent()) { + RecognitionResult result = finalResult.get(); + SensitiveColumn sensitiveColumn = new SensitiveColumn(); + sensitiveColumn.setDatabase(database); + sensitiveColumn.setTableName(column.getTableName()); + sensitiveColumn.setColumnName(column.getName()); + sensitiveColumn.setType(SensitiveColumnType.TABLE_COLUMN); + sensitiveColumn.setEnabled(true); + sensitiveColumn.setSensitiveRuleId(result.getMatchedRuleId()); + sensitiveColumn.setLevel(result.getLevel()); + SensitiveRule matchedRule = rules.stream() + .filter(r -> r.getId().equals(result.getMatchedRuleId())) + .findFirst() + .orElse(null); + if (matchedRule != null && matchedRule.getMaskingAlgorithmId() != null) { + sensitiveColumn.setMaskingAlgorithmId(matchedRule.getMaskingAlgorithmId()); + } else { + sensitiveColumn.setMaskingAlgorithmId(1L); + } + results.add(sensitiveColumn); + } + } + } + + return results; + } + + private List getTableColumns(ConnectionConfig connectionConfig, String databaseName, + String tableName) { + ConnectionSession session = new DefaultConnectSessionFactory(connectionConfig).generateSession(); + try { + DBSchemaAccessor accessor = DBSchemaAccessors.create(session); + return accessor.listTableColumns(databaseName, tableName); + } finally { + session.expire(); + } + } + + private List getScanningRules(Long projectId, List sensitiveRuleIds) { + SensitiveRule defaultRule = createDefaultScanningRule(); + return Collections.singletonList(defaultRule); + } + + private SensitiveRule createDefaultScanningRule() { + SensitiveRule rule = new SensitiveRule(); + rule.setId(-1L); + rule.setName("Single Table Scan Default Rule"); + rule.setEnabled(true); + rule.setType(SensitiveRuleType.AI); + rule.setAiSensitiveTypes(null); + rule.setAiCustomPrompt(null); + Long organizationId = authenticationFacade.currentOrganizationId(); + Long defaultAlgorithmId = algorithmService.getDefaultAlgorithmIdByOrganizationId(organizationId); + rule.setMaskingAlgorithmId(defaultAlgorithmId); + rule.setLevel(SensitiveLevel.MEDIUM); + rule.setBuiltin(true); + return rule; + } } + diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveRuleService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveRuleService.java index 0bc0fb9b05..22b97dd0bd 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveRuleService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SensitiveRuleService.java @@ -160,6 +160,8 @@ public SensitiveRule update(@NotNull Long projectId, @NotNull Long id, @NotNull entity.setLevel(rule.getLevel()); entity.setMaskingAlgorithmId(algorithmId); entity.setDescription(rule.getDescription()); + entity.setAiSensitiveTypes(rule.getAiSensitiveTypes()); + entity.setAiCustomPrompt(rule.getAiCustomPrompt()); ruleRepository.saveAndFlush(entity); log.info("Sensitive rule has been updated, id={}, name={}", entity.getId(), entity.getName()); return detail(entity.getProjectId(), entity.getId()); diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SingleTableScanTaskManager.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SingleTableScanTaskManager.java new file mode 100644 index 0000000000..2472426bc6 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/SingleTableScanTaskManager.java @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity; + +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.stereotype.Component; + +import com.oceanbase.odc.service.datasecurity.model.SensitiveColumn; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +/** + * Single-table scan task manager + */ +@Slf4j +@Component +public class SingleTableScanTaskManager { + + private final Map tasks = new ConcurrentHashMap<>(); + + @Autowired + @Qualifier("scanSensitiveColumnExecutor") + private ThreadPoolTaskExecutor executor; + + public String startTask(String taskId, Runnable scanTask) { + SingleTableScanTask task = new SingleTableScanTask(taskId); + tasks.put(taskId, task); + executor.submit(() -> { + try { + task.setStatus(TaskStatus.RUNNING); + scanTask.run(); + task.setStatus(TaskStatus.COMPLETED); + } catch (Exception e) { + log.error("Single table scan task failed, taskId={}", taskId, e); + task.setStatus(TaskStatus.FAILED); + task.setErrorMessage(e.getMessage()); + } + }); + + return taskId; + } + + public String startTask(Runnable scanTask) { + String taskId = UUID.randomUUID().toString(); + return startTask(taskId, scanTask); + } + + public SingleTableScanTask getTask(String taskId) { + return tasks.get(taskId); + } + + public void setTaskResult(String taskId, List result) { + SingleTableScanTask task = tasks.get(taskId); + if (task != null) { + task.setResult(result); + } + } + + public void setTaskError(String taskId, String errorMessage) { + SingleTableScanTask task = tasks.get(taskId); + if (task != null) { + task.setStatus(TaskStatus.FAILED); + task.setErrorMessage(errorMessage); + } + } + + public void cleanupTask(String taskId) { + tasks.remove(taskId); + } + + public enum TaskStatus { + PENDING, RUNNING, COMPLETED, FAILED + } + + @Data + public static class SingleTableScanTask { + private final String taskId; + private TaskStatus status = TaskStatus.PENDING; + private List result; + private String errorMessage; + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIConfig.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIConfig.java new file mode 100644 index 0000000000..a8a81a27b4 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIConfig.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.stereotype.Component; + +import com.oceanbase.odc.core.shared.constant.ErrorCodes; +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonBoolean; +import com.openai.core.JsonNumber; +import com.openai.core.JsonValue; + +import lombok.Data; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Data +@Component +public class AIConfig { + @Value("${odc.ai.enabled:false}") + private boolean enabled; + + @Value("${odc.ai.api-key:}") + private String apiKey; + + @Value("${odc.ai.base-url:https://api.openai.com}") + private String baseUrl; + + @Value("${odc.ai.model:gpt-3.5-turbo}") + private String model; + + private Boolean enableThinking = AIParam.DEFAULT_ENABLE_THINKING; + + private Double temperature = AIParam.DEFAULT_TEMPERATURE; + + private Double topP = AIParam.DEFAULT_TOP_P; + + private Integer topK = AIParam.DEFAULT_TOP_K; + + private Integer minP = AIParam.DEFAULT_MIN_P; + + public Map loadAdditionalParams() { + Map params = new HashMap<>(); + params.put("enable_thinking", JsonBoolean.from(this.enableThinking)); + params.put("top_k", JsonNumber.from(this.topK)); + params.put("min_p", JsonNumber.from(this.minP)); + return params; + } + + @Bean + @ConditionalOnProperty(name = "odc.ai.enabled", havingValue = "true") + public OpenAIClient openAIClient() { + if (apiKey == null || apiKey.trim().isEmpty()) { + throw new BadRequestException(ErrorCodes.AIConfigurationIncomplete, + new Object[] {"API key is not configured"}, + "AI service is enabled but API key is not configured. Please set odc.ai.api-key configuration."); + } + return OpenAIOkHttpClient.builder() + .apiKey(this.apiKey) + .baseUrl(this.baseUrl) + .build(); + } + + public boolean isAIAvailable() { + return enabled && apiKey != null && !apiKey.trim().isEmpty(); + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceService.java new file mode 100644 index 0000000000..a1bf62e906 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceService.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.Optional; + +import org.springframework.stereotype.Service; + +import com.oceanbase.odc.core.shared.constant.ErrorCodes; +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.openai.client.OpenAIClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Service +public class AIInferenceService { + + private final AIConfig aiConfig; + private final Optional openAIClient; + + public AIInferenceService(AIConfig aiConfig, Optional openAIClient) { + this.aiConfig = aiConfig; + this.openAIClient = openAIClient; + } + + private void checkAIAvailability() { + if (!aiConfig.isEnabled()) { + throw new BadRequestException(ErrorCodes.AIServiceNotAvailable, new Object[] {"AI service is not enabled"}, + "AI service is not enabled. Please contact administrator to enable AI service."); + } + if (!aiConfig.isAIAvailable()) { + throw new BadRequestException(ErrorCodes.AIConfigurationIncomplete, + new Object[] {"AI configuration is incomplete"}, + "AI configuration is incomplete. Please contact administrator to configure AI parameters."); + } + if (!openAIClient.isPresent()) { + throw new BadRequestException(ErrorCodes.AIClientNotInitialized, + new Object[] {"AI client is not initialized"}, + "AI client is not initialized. Please check AI configuration and restart service."); + } + } + + public ChatCompletion chat(String systemPrompt, String userPrompt) { + checkAIAvailability(); + + try { + ChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addSystemMessage(systemPrompt) + .addUserMessage(userPrompt) + .model(aiConfig.getModel()) + .temperature(aiConfig.getTemperature()) + .topP(aiConfig.getTopP()) + .additionalBodyProperties(aiConfig.loadAdditionalParams()) + .build(); + return openAIClient.get().chat().completions().create(params); + } catch (Exception e) { + throw new BadRequestException(ErrorCodes.AIInferenceServiceError, new Object[] {e.getMessage()}, + "Failed to call AI inference service: " + e.getMessage(), e); + } + } + + public boolean isAIAvailable() { + return aiConfig.isEnabled() && aiConfig.isAIAvailable() && openAIClient.isPresent(); + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIParam.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIParam.java new file mode 100644 index 0000000000..b5793e0664 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIParam.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public class AIParam { + + /** + * Default values for AI configuration + */ + public static final Boolean DEFAULT_ENABLE_THINKING = false; + public static final Double DEFAULT_TEMPERATURE = 0.1; + public static final Double DEFAULT_TOP_P = 1.0; + public static final Integer DEFAULT_TOP_K = 0; + public static final Integer DEFAULT_MIN_P = 0; + + public static final Integer DEFAULT_BATCH_SIZE_IN_TABLE = 30; +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIStatusResponse.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIStatusResponse.java new file mode 100644 index 0000000000..d5ed6d977b --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/AIStatusResponse.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import lombok.Data; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Data +public class AIStatusResponse { + private boolean enabled; + + private boolean available; + + private String model; + + private String baseUrl; + + private boolean apiKeyConfigured; +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoader.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoader.java new file mode 100644 index 0000000000..e36d837239 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoader.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import javax.annotation.PostConstruct; + +import org.springframework.stereotype.Component; + +import lombok.extern.slf4j.Slf4j; +import lombok.var; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Slf4j +@Component +public class PromptTemplateLoader { + private static final String SYSTEM_TEMPLATE_PATH = + "/ai-prompt-template/sensitive_column_recognize_system_prompt.txt"; + private static final String TYPES_PLACEHOLDER = "{sensitiveTypes}"; + private static final String PROMPT_PLACEHOLDER = "{customPrompt}"; + + private String systemTemplate; + + @PostConstruct + public void init() { + try (var inputStream = PromptTemplateLoader.class.getResourceAsStream(SYSTEM_TEMPLATE_PATH)) { + if (Objects.isNull(inputStream)) { + throw new IllegalStateException("AI system prompt template file not found: " + SYSTEM_TEMPLATE_PATH); + } + try (var reader = new java.io.BufferedReader(new java.io.InputStreamReader(inputStream))) { + this.systemTemplate = reader.lines().collect(Collectors.joining(System.lineSeparator())); + } + } catch (Exception e) { + log.error("Failed to load AI system prompt template: {}", e.getMessage(), e); + throw new IllegalStateException("Failed to load AI system prompt template", e); + } + } + + public String buildSystemPrompt(List sensitiveTypes, String customPrompt) { + if (this.systemTemplate == null || this.systemTemplate.isEmpty()) { + throw new IllegalStateException("System prompt template is not available. Check loading status."); + } + + String formattedTypes = (sensitiveTypes == null || sensitiveTypes.isEmpty()) + ? "No specified category." + : String.join(", ", sensitiveTypes); + + String formattedPrompt = (customPrompt == null || customPrompt.trim().isEmpty()) + ? "No supplementary rule." + : customPrompt; + + return this.systemTemplate + .replace(TYPES_PLACEHOLDER, formattedTypes) + .replace(PROMPT_PLACEHOLDER, formattedPrompt); + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ColumnRecognizerFactory.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ColumnRecognizerFactory.java similarity index 79% rename from server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ColumnRecognizerFactory.java rename to server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ColumnRecognizerFactory.java index 50cfc35c7a..c38ca4034e 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/ColumnRecognizerFactory.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ColumnRecognizerFactory.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.oceanbase.odc.service.datasecurity; +package com.oceanbase.odc.service.datasecurity.factory; import com.oceanbase.odc.core.shared.constant.ErrorCodes; import com.oceanbase.odc.core.shared.exception.UnsupportedException; import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.recognizer.AIColumnRecognizer; import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; import com.oceanbase.odc.service.datasecurity.recognizer.GroovyColumnRecognizer; import com.oceanbase.odc.service.datasecurity.recognizer.PathColumnRecognizer; @@ -34,16 +35,16 @@ public class ColumnRecognizerFactory { public static ColumnRecognizer create(@NonNull SensitiveRule rule) { switch (rule.getType()) { case REGEX: - return new RegexColumnRecognizer(rule.getDatabaseRegexExpression(), rule.getTableRegexExpression(), - rule.getColumnRegexExpression(), rule.getColumnCommentRegexExpression()); + return new RegexColumnRecognizer(rule); case PATH: - return new PathColumnRecognizer(rule.getPathIncludes(), rule.getPathExcludes()); + return new PathColumnRecognizer(rule); case GROOVY: - return new GroovyColumnRecognizer(rule.getGroovyScript()); + return new GroovyColumnRecognizer(rule); + case AI: + return new AIColumnRecognizer(rule); default: String errorMsg = String.format("Unsupported sensitive rule type: %s", rule.getType().name()); throw new UnsupportedException(ErrorCodes.BadArgument, new Object[] {errorMsg}, errorMsg); } } - } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactory.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactory.java new file mode 100644 index 0000000000..46bbba6109 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactory.java @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.factory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.springframework.stereotype.Component; + +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; +import com.oceanbase.odc.service.datasecurity.strategy.AIOnlyStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.JointRecognitionStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.RulesOnlyStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.ScanningStrategy; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Component +public class ScanningStrategyFactory { + + private final Map strategies = new HashMap<>(); + + public ScanningStrategyFactory() { + strategies.put(ScanningModeType.RULES_ONLY, new RulesOnlyStrategy()); + strategies.put(ScanningModeType.JOINT_RECOGNITION, new JointRecognitionStrategy()); + strategies.put(ScanningModeType.AI_ONLY, new AIOnlyStrategy()); + } + + public ScanningStrategy getStrategy(ScanningModeType mode) { + ScanningStrategy strategy = strategies.get(mode); + if (strategy == null) { + return new NoOpStrategy(); + } + return strategy; + } + + private static class NoOpStrategy implements ScanningStrategy { + @Override + public ScanResult scan(DBTableColumn column, + java.util.List basicRecognizers, + java.util.List aiRecognizers) { + return new ScanResult(Optional.empty(), Optional.empty()); + } + + @Override + public Map scanBatch(java.util.List columns, + java.util.List basicRecognizers, + java.util.List aiRecognizers) { + Map results = new HashMap<>(); + for (DBTableColumn column : columns) { + String columnKey = String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + results.put(columnKey, new ScanResult(Optional.empty(), Optional.empty())); + } + return results; + } + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/DefaultSensitiveType.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/DefaultSensitiveType.java new file mode 100644 index 0000000000..7fd990f5df --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/DefaultSensitiveType.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.model; + +import java.util.Optional; + +/** + * 13 Default Sensitivity Types Identified by AI + * + * @author fenyf + * @date 2025/8/1 + */ +public enum DefaultSensitiveType { + + PERSONAL_NAME_CHINESE("${com.oceanbase.odc.builtin-resource.masking-algorithm.personal-name-chinese.name}"), + + PERSONAL_NAME_ALPHABET("${com.oceanbase.odc.builtin-resource.masking-algorithm.personal-name-alphabet.name}"), + + NICKNAME("${com.oceanbase.odc.builtin-resource.masking-algorithm.nickname.name}"), + + EMAIL("${com.oceanbase.odc.builtin-resource.masking-algorithm.email.name}"), + + ADDRESS("${com.oceanbase.odc.builtin-resource.masking-algorithm.address.name}"), + + PHONE_NUMBER("${com.oceanbase.odc.builtin-resource.masking-algorithm.phone-number.name}"), + + FIXED_LINE_PHONE_NUMBER("${com.oceanbase.odc.builtin-resource.masking-algorithm.fixed-line-phone-number.name}"), + + CERTIFICATE_NUMBER("${com.oceanbase.odc.builtin-resource.masking-algorithm.certificate-number.name}"), + + BANK_CARD_NUMBER("${com.oceanbase.odc.builtin-resource.masking-algorithm.bank-card-number.name}"), + + LICENSE_PLATE_NUMBER("${com.oceanbase.odc.builtin-resource.masking-algorithm.license-plate-number.name}"), + + DEVICE_ID("${com.oceanbase.odc.builtin-resource.masking-algorithm.device-id.name}"), + + IP("${com.oceanbase.odc.builtin-resource.masking-algorithm.ip.name}"), + + MAC("${com.oceanbase.odc.builtin-resource.masking-algorithm.mac.name}"); + + private final String algorithmName; + + DefaultSensitiveType(String algorithmName) { + this.algorithmName = algorithmName; + } + + public String getAlgorithmName() { + return algorithmName; + } + + public static boolean isDefaultType(String sensitiveType) { + return findBestMatch(sensitiveType).isPresent(); + } + + /** + * Retrieve the corresponding algorithm name based on the name of the sensitive type. + */ + public static Optional getAlgorithmNameBySensitiveType(String sensitiveType) { + return findBestMatch(sensitiveType).map(DefaultSensitiveType::getAlgorithmName); + } + + public static Optional getByDisplayName(String sensitiveType) { + return findBestMatch(sensitiveType); + } + + /** + * Match sensitive type + */ + public static Optional findBestMatch(String sensitiveType) { + if (sensitiveType == null || sensitiveType.trim().isEmpty()) { + return Optional.empty(); + } + + String normalized = sensitiveType.toLowerCase().trim(); + + for (DefaultSensitiveType type : values()) { + if (type.name().toLowerCase().equals(normalized)) { + return Optional.of(type); + } + } + + for (DefaultSensitiveType type : values()) { + String hyphenFormat = type.name().toLowerCase().replace("_", "-"); + if (hyphenFormat.equals(normalized)) { + return Optional.of(type); + } + } + + for (DefaultSensitiveType type : values()) { + String enumName = type.name().toLowerCase(); + String hyphenFormat = enumName.replace("_", "-"); + if (enumName.contains(normalized) || normalized.contains(enumName.replace("_", "")) || + hyphenFormat.contains(normalized) || normalized.contains(hyphenFormat.replace("-", ""))) { + return Optional.of(type); + } + } + + return Optional.empty(); + } + +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/RecognitionResult.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/RecognitionResult.java new file mode 100644 index 0000000000..e3c004a132 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/RecognitionResult.java @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.model; + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +public class RecognitionResult { + // 基础信息 + private boolean matched; + private Long matchedRuleId; + private SensitiveLevel level; + private SensitiveRuleType sourceRuleType; + + // AI 规则 + private String sensitiveType; // AI 判断出的具体敏感类型 +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanResult.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanResult.java new file mode 100644 index 0000000000..9fe6da5253 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanResult.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.model; + +import java.util.Optional; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * @author fenyf + * @date 2025/7/18 17:52 + */ +@Getter +@AllArgsConstructor +public class ScanResult { + private final Optional basicRuleResult; + private final Optional aiRuleResult; + + public Optional getFinalResult(ScanningModeType scanningMode) { + switch (scanningMode) { + case RULES_ONLY: + return basicRuleResult; + case AI_ONLY: + return aiRuleResult; + case JOINT_RECOGNITION: + return basicRuleResult.isPresent() ? basicRuleResult : aiRuleResult; + default: + return Optional.empty(); + } + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanningModeType.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanningModeType.java new file mode 100644 index 0000000000..5355078bf1 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/ScanningModeType.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.model; + +/** + * @author fenyf + * @date 2025/7/18 17:52 + */ +public enum ScanningModeType { + + RULES_ONLY, + + JOINT_RECOGNITION, + + AI_ONLY; +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningReq.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningReq.java index fd77c7f642..ba36f1c087 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningReq.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningReq.java @@ -35,4 +35,6 @@ public class SensitiveColumnScanningReq { @NotNull private Boolean allSensitiveRules; private List sensitiveRuleIds; + @NotNull + private ScanningModeType scanningMode = ScanningModeType.JOINT_RECOGNITION; } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningTaskInfo.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningTaskInfo.java index 9c38d88b29..68cfb2f262 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningTaskInfo.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveColumnScanningTaskInfo.java @@ -41,6 +41,7 @@ public class SensitiveColumnScanningTaskInfo { private Date completeTime; private ErrorCode errorCode; private String errorMsg; + private volatile boolean cancelled = false; public SensitiveColumnScanningTaskInfo(@NonNull String taskId, @NonNull Long projectId, @NonNull Integer allTableCount) { @@ -81,14 +82,23 @@ public synchronized void setErrorMsg(String msg) { this.errorMsg = msg; } + public synchronized void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } + + public boolean isCancelled() { + return this.cancelled; + } + public enum ScanningTaskStatus { CREATED, RUNNING, SUCCESS, - FAILED; + FAILED, + CANCELLED; public boolean isCompleted() { - return this == SUCCESS || this == FAILED; + return this == SUCCESS || this == FAILED || this == CANCELLED; } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRule.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRule.java index 770038c56a..b812eb59d5 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRule.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRule.java @@ -53,7 +53,6 @@ public class SensitiveRule implements SecurityResource, SingleOrganizationResour @JsonProperty(access = Access.READ_ONLY) private Long projectId; - @NotNull private SensitiveRuleType type; private String databaseRegexExpression; @@ -70,6 +69,10 @@ public class SensitiveRule implements SecurityResource, SingleOrganizationResour private List pathExcludes = new ArrayList<>(); + private List aiSensitiveTypes = new ArrayList<>(); + + private String aiCustomPrompt; + @NotNull private Long maskingAlgorithmId; diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRuleType.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRuleType.java index 0673cbbf95..54ba6a4be0 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRuleType.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SensitiveRuleType.java @@ -33,5 +33,10 @@ public enum SensitiveRuleType { /** * Path expression fuzzy match */ - PATH + PATH, + + /** + * AI-based sensitive data detection + */ + AI } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SingleTableScanReq.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SingleTableScanReq.java new file mode 100644 index 0000000000..04dbf36e76 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/model/SingleTableScanReq.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.model; + +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; + +import lombok.Data; + +/** + * Single-table sensitive column scan request + * + * @author fenyf + * @date 2025/8/18 17:52 + */ +@Data +public class SingleTableScanReq { + + @NotNull + private Long databaseId; + + @NotBlank + private String tableName; + + @NotNull + private ScanningModeType scanningMode = ScanningModeType.JOINT_RECOGNITION; + +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizer.java new file mode 100644 index 0000000000..9870d29d00 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizer.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.recognizer; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; +import com.oceanbase.odc.core.shared.constant.ErrorCodes; +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.oceanbase.odc.service.common.util.SpringContextUtil; +import com.oceanbase.odc.service.datasecurity.ai.AIInferenceService; +import com.oceanbase.odc.service.datasecurity.ai.AIParam; +import com.oceanbase.odc.service.datasecurity.ai.PromptTemplateLoader; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; +import com.openai.models.chat.completions.ChatCompletion; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +@Slf4j +public class AIColumnRecognizer implements ColumnRecognizer { + + private final SensitiveRule aiRule; + private static final int BATCH_SIZE = AIParam.DEFAULT_BATCH_SIZE_IN_TABLE; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final Pattern JSON_PATTERN = Pattern + .compile("(?s)```json\\s*([\\{\\[].*[\\}\\]])\\s*```|([\\{\\[].*[\\}\\]])"); + + public AIColumnRecognizer(SensitiveRule rule) { + this.aiRule = rule; + } + + @Override + public Optional recognize(DBTableColumn column) { + Map> batchResult = recognizeBatch(Collections.singletonList(column)); + String columnKey = getColumnKey(column); + return batchResult.getOrDefault(columnKey, Optional.empty()); + } + + /** + * If the data in the table columns is too large, scan them in batches. + * + * @param columns list of columns {@link DBTableColumn} + * @return + */ + @Override + public Map> recognizeBatch(List columns) { + if (columns == null || columns.isEmpty()) { + return Collections.emptyMap(); + } + PromptTemplateLoader promptTemplateLoader = SpringContextUtil.getBean(PromptTemplateLoader.class); + AIInferenceService aiService = SpringContextUtil.getBean(AIInferenceService.class); + + Map> finalAiResults = new HashMap<>(); + + if (columns.size() > BATCH_SIZE) { + List> batches = Lists.partition(columns, BATCH_SIZE); + try { + for (List batch : batches) { + processBatch(batch, promptTemplateLoader, aiService, finalAiResults); + } + } catch (BadRequestException e) { + throw e; + } catch (Exception e) { + log.error("Failed to process AI column recognition batch", e); + return finalAiResults; + } + } else { + try { + processBatch(columns, promptTemplateLoader, aiService, finalAiResults); + } catch (BadRequestException e) { + throw e; + } catch (Exception e) { + log.error("Failed to process AI column recognition", e); + return finalAiResults; + } + } + return finalAiResults; + } + + /** + * Process a single batch of column data + */ + private void processBatch(List batch, PromptTemplateLoader promptTemplateLoader, + AIInferenceService aiService, Map> finalAiResults) throws IOException { + String systemPrompt = promptTemplateLoader.buildSystemPrompt(aiRule.getAiSensitiveTypes(), + aiRule.getAiCustomPrompt()); + String userPrompt = buildUserPrompt(batch); + ChatCompletion completion = aiService.chat(systemPrompt, userPrompt); + String rawContent = completion.choices().get(0).message().content().orElse("[]"); + + Matcher matcher = JSON_PATTERN.matcher(rawContent); + String jsonArrayResponse = null; + if (matcher.find()) { + jsonArrayResponse = Optional.ofNullable(matcher.group(1)).orElse(matcher.group(2)); + } + + if (jsonArrayResponse == null) { + throw new BadRequestException(ErrorCodes.AIResponseFormatError, + new Object[] {"No valid JSON array found in AI response"}, + "AI response does not contain valid JSON format: " + rawContent); + } + + List batchResults; + try { + batchResults = objectMapper.readValue(jsonArrayResponse, + new TypeReference>() {}); + } catch (Exception e) { + throw new BadRequestException(ErrorCodes.AIResponseFormatError, + new Object[] {"Failed to parse JSON: " + e.getMessage()}, + "Failed to parse AI response JSON: " + jsonArrayResponse, e); + } + int maxIndex = Math.min(batch.size(), batchResults.size()); + for (int i = 0; i < maxIndex; i++) { + DBTableColumn column = batch.get(i); + String columnKey = getColumnKey(column); + AiResponseDto dto = batchResults.get(i); + + if (dto.isSensitive()) { + RecognitionResult result = RecognitionResult.builder() + .matched(true) + .matchedRuleId(this.aiRule.getId()) + .level(dto.getRiskLevel()) + .sourceRuleType(SensitiveRuleType.AI) + .sensitiveType(dto.getSensitiveCategory()) + .build(); + finalAiResults.put(columnKey, Optional.of(result)); + } else { + finalAiResults.put(columnKey, Optional.empty()); + } + } + if (batchResults.size() != batch.size()) { + log.warn("AI response count ({}) does not match input column count ({})", + batchResults.size(), batch.size()); + } + } + + /** + * Build user prompt (JSON array of column data) + */ + private String buildUserPrompt(List batch) throws IOException { + List> columnMetadataList = batch.stream().map(c -> { + Map meta = new HashMap<>(); + meta.put("schemaName", c.getSchemaName()); + meta.put("tableName", c.getTableName()); + meta.put("columnName", c.getName()); + meta.put("comment", c.getComment()); + meta.put("dataType", c.getTypeName()); + return meta; + }).collect(Collectors.toList()); + return objectMapper.writeValueAsString(columnMetadataList); + } + + // Inner class for holding AI response JSON data + @Data + private static class AiResponseDto { + private boolean sensitive; + private SensitiveLevel riskLevel; + private String sensitiveCategory; + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/ColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/ColumnRecognizer.java index 9a301990bf..fc3e64fb02 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/ColumnRecognizer.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/ColumnRecognizer.java @@ -15,11 +15,18 @@ */ package com.oceanbase.odc.service.datasecurity.recognizer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; /** - * @author gaoda.xy - * @date 2023/5/30 10:32 + * @author fenyf + * @date 2025/8/10 12:41 */ public interface ColumnRecognizer { @@ -29,6 +36,38 @@ public interface ColumnRecognizer { * @param column column {@link DBTableColumn} * @return recognizing result */ - boolean recognize(DBTableColumn column); + Optional recognize(DBTableColumn column); + + /** + * Batch recognizing the columns in database + * + * @param columns list of columns {@link DBTableColumn} + * @return map of recognizing results, key is column identifier, value is recognizing result + */ + default Map> recognizeBatch(List columns) { + if (columns == null || columns.isEmpty()) { + return Collections.emptyMap(); + } + Map> results = new HashMap<>(); + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional result = recognize(column); + results.put(columnKey, result); + } + return results; + } + + /** + * Generate a unique key for the column + * + * @param column column {@link DBTableColumn} + * @return unique column key + */ + default String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizer.java index 9cf782d9a7..ceaa3f7e47 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizer.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizer.java @@ -15,8 +15,13 @@ */ package com.oceanbase.odc.service.datasecurity.recognizer; +import java.util.Optional; + import org.codehaus.groovy.control.CompilerConfiguration; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.odc.service.datasecurity.util.SecureAstCustomizerUtil; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; @@ -32,26 +37,39 @@ */ public class GroovyColumnRecognizer implements ColumnRecognizer { + private final SensitiveRule rule; private final Script script; private static final String COLUMN_KEYWORD = "column"; - public GroovyColumnRecognizer(String groovyScript) { + public GroovyColumnRecognizer(SensitiveRule rule) { + this.rule = rule; CompilerConfiguration config = new CompilerConfiguration(); config.addCompilationCustomizers(SecureAstCustomizerUtil.buildSecureASTCustomizer()); GroovyShell shell = new GroovyShell(config); - this.script = shell.parse(groovyScript); + this.script = shell.parse(rule.getGroovyScript()); } @Override - public boolean recognize(DBTableColumn column) { + public Optional recognize(DBTableColumn column) { try { GroovyColumnMeta groovyColumnMeta = new GroovyColumnMeta(column); Binding binding = new Binding(); binding.setVariable(COLUMN_KEYWORD, groovyColumnMeta); script.setBinding(binding); - return (boolean) script.run(); + + boolean matched = (boolean) script.run(); + if (matched) { + RecognitionResult result = RecognitionResult.builder() + .matched(true) + .matchedRuleId(this.rule.getId()) + .level(this.rule.getLevel()) + .sourceRuleType(SensitiveRuleType.GROOVY) + .build(); + return Optional.of(result); + } + return Optional.empty(); } catch (Exception e) { - return false; + return Optional.empty(); } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizer.java index 3669113ead..430bb6f274 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizer.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizer.java @@ -17,12 +17,16 @@ import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.IOCase; import com.oceanbase.odc.core.shared.PreConditions; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; /** @@ -31,33 +35,41 @@ */ public class PathColumnRecognizer implements ColumnRecognizer { + private final SensitiveRule rule; private final List pathIncludeMatchers; private final List pathExcludeMatchers; - public PathColumnRecognizer(List pathIncludes, List pathExcludes) { - pathIncludeMatchers = pathIncludes.stream().map(FieldPathMatcher::new).collect(Collectors.toList()); - pathExcludeMatchers = pathExcludes.stream().map(FieldPathMatcher::new).collect(Collectors.toList()); + public PathColumnRecognizer(SensitiveRule rule) { + this.rule = rule; + pathIncludeMatchers = rule.getPathIncludes().stream().map(FieldPathMatcher::new).collect(Collectors.toList()); + pathExcludeMatchers = rule.getPathExcludes().stream().map(FieldPathMatcher::new).collect(Collectors.toList()); } @Override - public boolean recognize(DBTableColumn column) { + public Optional recognize(DBTableColumn column) { try { String schemaName = column.getSchemaName(); String tableName = column.getTableName(); String columnName = column.getName(); for (FieldPathMatcher matcher : pathExcludeMatchers) { if (matcher.match(schemaName, tableName, columnName)) { - return false; + return Optional.empty(); } } for (FieldPathMatcher matcher : pathIncludeMatchers) { if (matcher.match(schemaName, tableName, columnName)) { - return true; + RecognitionResult result = RecognitionResult.builder() + .matched(true) + .matchedRuleId(this.rule.getId()) + .level(this.rule.getLevel()) + .sourceRuleType(SensitiveRuleType.PATH) + .build(); + return Optional.of(result); } } - return false; + return Optional.empty(); } catch (Exception e) { - return false; + return Optional.empty(); } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizer.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizer.java index 02b15d1152..117d4b04cf 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizer.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizer.java @@ -15,9 +15,13 @@ */ package com.oceanbase.odc.service.datasecurity.recognizer; +import java.util.Optional; import java.util.regex.Pattern; import com.oceanbase.odc.common.util.StringUtils; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; import lombok.NonNull; @@ -28,6 +32,7 @@ */ public class RegexColumnRecognizer implements ColumnRecognizer { + private final SensitiveRule rule; private final Pattern databasePattern; private final Pattern tablePattern; private final Pattern columnPattern; @@ -35,35 +40,50 @@ public class RegexColumnRecognizer implements ColumnRecognizer { private static final long MATCH_TIMEOUT_MILLIS = 100L; - public RegexColumnRecognizer(String databaseRegex, String tableRegex, String columnRegex, String commentRegex) { - databasePattern = StringUtils.isNotBlank(databaseRegex) ? Pattern.compile(databaseRegex) : null; - tablePattern = StringUtils.isNotBlank(tableRegex) ? Pattern.compile(tableRegex) : null; - columnPattern = StringUtils.isNotBlank(columnRegex) ? Pattern.compile(columnRegex) : null; - columnCommentPattern = StringUtils.isNotBlank(commentRegex) ? Pattern.compile(commentRegex) : null; + public RegexColumnRecognizer(SensitiveRule rule) { + this.rule = rule; + databasePattern = StringUtils.isNotBlank(rule.getDatabaseRegexExpression()) + ? Pattern.compile(rule.getDatabaseRegexExpression()) + : null; + tablePattern = + StringUtils.isNotBlank(rule.getTableRegexExpression()) ? Pattern.compile(rule.getTableRegexExpression()) + : null; + columnPattern = StringUtils.isNotBlank(rule.getColumnRegexExpression()) + ? Pattern.compile(rule.getColumnRegexExpression()) + : null; + columnCommentPattern = StringUtils.isNotBlank(rule.getColumnCommentRegexExpression()) + ? Pattern.compile(rule.getColumnCommentRegexExpression()) + : null; } @Override - public boolean recognize(DBTableColumn column) { + public Optional recognize(DBTableColumn column) { try { if (databasePattern != null && !databasePattern .matcher(new TimeoutCharSequence(column.getSchemaName(), getTimeoutMillis())).matches()) { - return false; + return Optional.empty(); } if (tablePattern != null && !tablePattern .matcher(new TimeoutCharSequence(column.getTableName(), getTimeoutMillis())).matches()) { - return false; + return Optional.empty(); } if (columnPattern != null && !columnPattern .matcher(new TimeoutCharSequence(column.getName(), getTimeoutMillis())).matches()) { - return false; + return Optional.empty(); } if (columnCommentPattern != null && !columnCommentPattern .matcher(new TimeoutCharSequence(column.getComment(), getTimeoutMillis())).matches()) { - return false; + return Optional.empty(); } - return true; + RecognitionResult result = RecognitionResult.builder() + .matched(true) + .matchedRuleId(this.rule.getId()) + .level(this.rule.getLevel()) + .sourceRuleType(SensitiveRuleType.REGEX) + .build(); + return Optional.of(result); } catch (Exception e) { - return false; + return Optional.empty(); } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategy.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategy.java new file mode 100644 index 0000000000..78d6bb0f61 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategy.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public class AIOnlyStrategy extends AbstractScanningStrategy { + + @Override + public ScanResult scan(DBTableColumn column, List basicRecognizers, + List aiRecognizers) { + Optional aiResult = findFirstMatch(aiRecognizers, column); + return new ScanResult(Optional.empty(), aiResult); + } + + @Override + public Map scanBatch(List columns, List basicRecognizers, + List aiRecognizers) { + Map> aiResults = findAllFirstMatches(aiRecognizers, columns); + Map results = new HashMap<>(); + + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional aiResult = aiResults.getOrDefault(columnKey, Optional.empty()); + results.put(columnKey, new ScanResult(Optional.empty(), aiResult)); + } + + return results; + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategy.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategy.java new file mode 100644 index 0000000000..472e2c7651 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategy.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public abstract class AbstractScanningStrategy implements ScanningStrategy { + protected Optional findFirstMatch(List recognizers, DBTableColumn column) { + for (ColumnRecognizer recognizer : recognizers) { + Optional result = recognizer.recognize(column); + if (result.isPresent()) { + return result; + } + } + return Optional.empty(); + } + + protected Map> findAllFirstMatches(List recognizers, + List columns) { + if (recognizers.isEmpty() || columns.isEmpty()) { + return createEmptyResultMap(columns); + } + if (recognizers.size() == 1) { + ColumnRecognizer recognizer = recognizers.get(0); + return recognizer.recognizeBatch(columns); + } + + Map> results = new HashMap<>(); + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional result = findFirstMatch(recognizers, column); + results.put(columnKey, result); + } + return results; + } + + protected Map> createEmptyResultMap(List columns) { + Map> results = new HashMap<>(); + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + results.put(columnKey, Optional.empty()); + } + return results; + } + + protected String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategy.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategy.java new file mode 100644 index 0000000000..6156f6531b --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategy.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public class JointRecognitionStrategy extends AbstractScanningStrategy { + + @Override + public ScanResult scan(DBTableColumn column, List basicRecognizers, + List aiRecognizers) { + Optional basicResult = findFirstMatch(basicRecognizers, column); + if (basicResult.isPresent()) { + return new ScanResult(basicResult, Optional.empty()); + } + Optional aiResult = findFirstMatch(aiRecognizers, column); + return new ScanResult(Optional.empty(), aiResult); + } + + @Override + public Map scanBatch(List columns, List basicRecognizers, + List aiRecognizers) { + Map> basicResults = findAllFirstMatches(basicRecognizers, columns); + + List remainingColumns = new ArrayList<>(); + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional basicResult = basicResults.getOrDefault(columnKey, Optional.empty()); + if (!basicResult.isPresent()) { + remainingColumns.add(column); + } + } + Map> aiResults = findAllFirstMatches(aiRecognizers, remainingColumns); + + Map results = new HashMap<>(); + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional basicResult = basicResults.getOrDefault(columnKey, Optional.empty()); + + if (basicResult.isPresent()) { + results.put(columnKey, new ScanResult(basicResult, Optional.empty())); + } else { + Optional aiResult = aiResults.getOrDefault(columnKey, Optional.empty()); + results.put(columnKey, new ScanResult(Optional.empty(), aiResult)); + } + } + + return results; + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategy.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategy.java new file mode 100644 index 0000000000..e209e31ec8 --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategy.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public class RulesOnlyStrategy extends AbstractScanningStrategy { + + @Override + public ScanResult scan(DBTableColumn column, List basicRecognizers, + List aiRecognizers) { + Optional basicResult = findFirstMatch(basicRecognizers, column); + return new ScanResult(basicResult, Optional.empty()); + } + + @Override + public Map scanBatch(List columns, List basicRecognizers, + List aiRecognizers) { + Map> basicResults = findAllFirstMatches(basicRecognizers, columns); + Map results = new HashMap<>(); + + for (DBTableColumn column : columns) { + String columnKey = getColumnKey(column); + Optional basicResult = basicResults.getOrDefault(columnKey, Optional.empty()); + results.put(columnKey, new ScanResult(basicResult, Optional.empty())); + } + + return results; + } +} diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/ScanningStrategy.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/ScanningStrategy.java new file mode 100644 index 0000000000..b156e92c4f --- /dev/null +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/strategy/ScanningStrategy.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.List; +import java.util.Map; + +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +/** + * @author fenyf + * @date 2025/8/10 12:41 + */ +public interface ScanningStrategy { + + ScanResult scan(DBTableColumn column, List basicRecognizers, + List aiRecognizers); + + Map scanBatch(List columns, List basicRecognizers, + List aiRecognizers); +} diff --git a/server/odc-service/src/main/resources/ai-prompt-template/sensitive_column_recognize_system_prompt.txt b/server/odc-service/src/main/resources/ai-prompt-template/sensitive_column_recognize_system_prompt.txt new file mode 100644 index 0000000000..aeec7f1e50 --- /dev/null +++ b/server/odc-service/src/main/resources/ai-prompt-template/sensitive_column_recognize_system_prompt.txt @@ -0,0 +1,81 @@ +You are a data classification expert specializing in data privacy and security. Your task is to determine if a database column's metadata likely belongs to sensitive data . + +Analyze the JSON array of database columns that I will provide in your next message, based on the following criteria: + +**1. Input data structure:** +You will receive a JSON array, where each object represents a database column. Each column originates from a table in the actual backend business. The column information includes schemaName, tableName, columnName, comment, and dataType. + +**2. Sensitive Categories Check Rule:** +If there is "No specified category", you should independently determine whether it is sensitive. In this case, the sensitiveCategory should return "null". If there are some specified categories, you MUST and can ONLY select categories from the following "Sensitive Categories to Check" list as the identification result. If the sensitive columns you identified do not fall within the provided categories, simply classify them as non-sensitive. +Sensitive Categories to Check: {sensitiveTypes} + +**3. Supplementary Rule:** +This task may involve a "Supplementary Recognition Rule". If it exists, please ensure that the content of "Supplementary Recognition Rule" is given the highest priority. If not, "No supplementary rule" will be displayed here. Please ignore this item. +Supplementary Recognition Rule: {customPrompt} + +**CRITICAL CONSTRAINT - STRICTLY ENFORCE:** + +1. You are STRICTLY FORBIDDEN from outputting any sensitive type that is NOT explicitly listed in the "Sensitive Categories to Check" section above. +If a column is sensitive but doesn't match ANY of the specified categories, you MUST set "sensitiveCategory" to null. +Violating this constraint will result in system errors. + +2. Based on all the information above, return a JSON array where each object corresponds to a column from the input array in the SAME order. +For each column, respond with a JSON object in the following format ONLY. Do not add any other text or explanations. +The final output should be a single, valid JSON array. Please do not output Markdown code blocks, such as ```json. + +3. The number of columns in your output MUST be exactly equal to the number of columns in the input. +Do not omit any columns, and do not add extra columns. Each input column must have a corresponding output object in the same order. + +Format for each object in the output array: +{ + "sensitive": boolean, + "riskLevel": "HIGH" | "MEDIUM" | "LOW", + "sensitiveCategory": string (MUST be one of the categories listed in "Sensitive Categories to Check" section, or null) +} + +**Example:** +Assuming the specified categories include "address, phone-number, email, ip and fixed-line-phone-number" and input json array is: +[ + { + "schemaName": "education_db", + "tableName": "students", + "columnName": "student_address" + "comment": "The student's home address" + "dataType": "TEXT" + }, + { + "schemaName": "education_db", + "tableName": "students", + "columnName": "student_home_phone" + "comment": "The student's fixed home phone" + "dataType": "VARCHAR(20)" + }, + { + "schemaName": "education_db", + "tableName": "students", + "columnName": "math_grade" + "comment": "The students' math exam scores" + "dataType": "TINYINT" + } +] + +The expected output format for 3 columns: +[ + { + "sensitive": true, + "riskLevel": "HIGH", + "sensitiveCategory": "address" + }, + { + "sensitive": true, + "riskLevel": "MEDIUM", + "sensitiveCategory": "fixed-line-phone-number" + }, + { + "sensitive": false, + "riskLevel": "LOW", + "sensitiveCategory": null + } +] + + diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIConfigTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIConfigTest.java new file mode 100644 index 0000000000..7788d6cc77 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIConfigTest.java @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.test.util.ReflectionTestUtils; + +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.openai.client.OpenAIClient; +import com.openai.core.JsonValue; + +@RunWith(MockitoJUnitRunner.class) +public class AIConfigTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private AIConfig aiConfig; + + @Before + public void setUp() { + aiConfig = new AIConfig(); + } + + @Test + public void test_isAIAvailable_enabledAndApiKeySet_returnsTrue() { + // Given + ReflectionTestUtils.setField(aiConfig, "enabled", true); + ReflectionTestUtils.setField(aiConfig, "apiKey", "test-api-key"); + + // When + boolean result = aiConfig.isAIAvailable(); + + // Then + Assert.assertTrue(result); + } + + @Test + public void test_isAIAvailable_disabledWithApiKey_returnsFalse() { + // Given + ReflectionTestUtils.setField(aiConfig, "enabled", false); + ReflectionTestUtils.setField(aiConfig, "apiKey", "test-api-key"); + + // When + boolean result = aiConfig.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_enabledWithoutApiKey_returnsFalse() { + // Given + ReflectionTestUtils.setField(aiConfig, "enabled", true); + ReflectionTestUtils.setField(aiConfig, "apiKey", ""); + + // When + boolean result = aiConfig.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_enabledWithNullApiKey_returnsFalse() { + // Given + ReflectionTestUtils.setField(aiConfig, "enabled", true); + ReflectionTestUtils.setField(aiConfig, "apiKey", null); + + // When + boolean result = aiConfig.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_enabledWithWhitespaceApiKey_returnsFalse() { + // Given + ReflectionTestUtils.setField(aiConfig, "enabled", true); + ReflectionTestUtils.setField(aiConfig, "apiKey", " "); + + // When + boolean result = aiConfig.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_loadAdditionalParams_defaultValues_returnsCorrectMap() { + // Given - using default values + + // When + Map params = aiConfig.loadAdditionalParams(); + + // Then + Assert.assertNotNull(params); + Assert.assertEquals(3, params.size()); + Assert.assertTrue(params.containsKey("enable_thinking")); + Assert.assertTrue(params.containsKey("top_k")); + Assert.assertTrue(params.containsKey("min_p")); + } + + @Test + public void test_loadAdditionalParams_customValues_returnsCorrectMap() { + // Given + ReflectionTestUtils.setField(aiConfig, "enableThinking", false); + ReflectionTestUtils.setField(aiConfig, "topK", 50); + ReflectionTestUtils.setField(aiConfig, "minP", 10); + + // When + Map params = aiConfig.loadAdditionalParams(); + + // Then + Assert.assertNotNull(params); + Assert.assertEquals(3, params.size()); + Assert.assertTrue(params.containsKey("enable_thinking")); + Assert.assertTrue(params.containsKey("top_k")); + Assert.assertTrue(params.containsKey("min_p")); + } + + @Test + public void test_openAIClient_validApiKey_returnsClient() { + // Given + ReflectionTestUtils.setField(aiConfig, "apiKey", "test-api-key"); + ReflectionTestUtils.setField(aiConfig, "baseUrl", "https://api.openai.com"); + + // When + OpenAIClient client = aiConfig.openAIClient(); + + // Then + Assert.assertNotNull(client); + } + + @Test + public void test_openAIClient_nullApiKey_throwsException() { + // Given + ReflectionTestUtils.setField(aiConfig, "apiKey", null); + thrown.expect(BadRequestException.class); + thrown.expectMessage("API key is not configured"); + + // When + aiConfig.openAIClient(); + } + + @Test + public void test_openAIClient_emptyApiKey_throwsException() { + // Given + ReflectionTestUtils.setField(aiConfig, "apiKey", ""); + thrown.expect(BadRequestException.class); + thrown.expectMessage("API key is not configured"); + + // When + aiConfig.openAIClient(); + } + + @Test + public void test_openAIClient_whitespaceApiKey_throwsException() { + // Given + ReflectionTestUtils.setField(aiConfig, "apiKey", " "); + thrown.expect(BadRequestException.class); + thrown.expectMessage("API key is not configured"); + + // When + aiConfig.openAIClient(); + } + + @Test + public void test_gettersAndSetters_workCorrectly() { + // Test enabled + aiConfig.setEnabled(true); + Assert.assertTrue(aiConfig.isEnabled()); + + // Test apiKey + aiConfig.setApiKey("test-key"); + Assert.assertEquals("test-key", aiConfig.getApiKey()); + + // Test baseUrl + aiConfig.setBaseUrl("https://test.com"); + Assert.assertEquals("https://test.com", aiConfig.getBaseUrl()); + + // Test model + aiConfig.setModel("gpt-4"); + Assert.assertEquals("gpt-4", aiConfig.getModel()); + + // Test temperature + aiConfig.setTemperature(0.8); + Assert.assertEquals(Double.valueOf(0.8), aiConfig.getTemperature()); + + // Test topP + aiConfig.setTopP(0.9); + Assert.assertEquals(Double.valueOf(0.9), aiConfig.getTopP()); + + // Test enableThinking + aiConfig.setEnableThinking(false); + Assert.assertFalse(aiConfig.getEnableThinking()); + + // Test topK + aiConfig.setTopK(100); + Assert.assertEquals(Integer.valueOf(100), aiConfig.getTopK()); + + // Test minP + aiConfig.setMinP(20); + Assert.assertEquals(Integer.valueOf(20), aiConfig.getMinP()); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceServiceTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceServiceTest.java new file mode 100644 index 0000000000..977827bc26 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/AIInferenceServiceTest.java @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.openai.client.OpenAIClient; +import com.openai.core.JsonValue; +import com.openai.models.chat.completions.ChatCompletion; + +@RunWith(MockitoJUnitRunner.class) +public class AIInferenceServiceTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private AIConfig aiConfig; + + @Mock + private OpenAIClient openAIClient; + + @Mock + private ChatCompletion chatCompletion; + + private AIInferenceService aiInferenceService; + + @Before + public void setUp() { + // Setup mock for openAIClient.chat().completions().create() call chain + } + + @Test + public void test_chat_allConditionsMet_callsConfigMethods() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + setupValidAIConfig(); + String systemPrompt = "You are a helpful assistant"; + String userPrompt = "Hello, world!"; + + // When - This will throw an exception due to actual OpenAI call, but we can verify config calls + try { + aiInferenceService.chat(systemPrompt, userPrompt); + } catch (Exception e) { + } + + // Then - Verify that config methods were called + Mockito.verify(aiConfig).isEnabled(); + Mockito.verify(aiConfig).isAIAvailable(); + } + + @Test + public void test_chat_aiNotEnabled_throwsException() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + Mockito.when(aiConfig.isEnabled()).thenReturn(false); + thrown.expect(BadRequestException.class); + thrown.expectMessage("AI service is not enabled"); + + // When + aiInferenceService.chat("system", "user"); + } + + @Test + public void test_chat_aiNotAvailable_throwsException() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(false); + thrown.expect(BadRequestException.class); + thrown.expectMessage("AI configuration is incomplete"); + + // When + aiInferenceService.chat("system", "user"); + } + + @Test + public void test_chat_clientNotPresent_throwsException() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.empty()); + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(true); + thrown.expect(BadRequestException.class); + thrown.expectMessage("AI client is not initialized"); + + // When + aiInferenceService.chat("system", "user"); + } + + @Test + public void test_chat_clientThrowsException_wrapsException() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + setupValidAIConfig(); + // Note: This test verifies exception wrapping behavior + // The actual OpenAI client will throw an exception due to invalid configuration + thrown.expect(BadRequestException.class); + thrown.expectMessage("Failed to call AI inference service"); + + // When + aiInferenceService.chat("system", "user"); + } + + @Test + public void test_chat_verifyParametersPassedCorrectly() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + setupValidAIConfig(); + String systemPrompt = "You are a helpful assistant"; + String userPrompt = "Hello, world!"; + String model = "gpt-3.5-turbo"; + Double temperature = 0.7; + Double topP = 0.9; + Map additionalParams = new HashMap<>(); + + Mockito.when(aiConfig.getModel()).thenReturn(model); + Mockito.when(aiConfig.getTemperature()).thenReturn(temperature); + Mockito.when(aiConfig.getTopP()).thenReturn(topP); + Mockito.when(aiConfig.loadAdditionalParams()).thenReturn(additionalParams); + + // When - This will throw an exception but we can verify config method calls + try { + aiInferenceService.chat(systemPrompt, userPrompt); + } catch (Exception e) { + // Expected - actual OpenAI call will fail + } + + // Then - Verify config methods were called + Mockito.verify(aiConfig).getModel(); + Mockito.verify(aiConfig).getTemperature(); + Mockito.verify(aiConfig).getTopP(); + Mockito.verify(aiConfig).loadAdditionalParams(); + } + + @Test + public void test_isAIAvailable_allConditionsMet_returnsTrue() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(true); + + // When + boolean result = aiInferenceService.isAIAvailable(); + + // Then + Assert.assertTrue(result); + } + + @Test + public void test_isAIAvailable_aiNotEnabled_returnsFalse() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + Mockito.when(aiConfig.isEnabled()).thenReturn(false); + // Note: aiConfig.isAIAvailable() stubbing removed as it's not called due to short-circuit + // evaluation + + // When + boolean result = aiInferenceService.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_aiNotAvailable_returnsFalse() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(false); + + // When + boolean result = aiInferenceService.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_clientNotPresent_returnsFalse() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.empty()); + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(true); + + // When + boolean result = aiInferenceService.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_isAIAvailable_noConditionsMet_returnsFalse() { + // Given + aiInferenceService = new AIInferenceService(aiConfig, Optional.empty()); + Mockito.when(aiConfig.isEnabled()).thenReturn(false); + // Note: aiConfig.isAIAvailable() stubbing removed as it's not called due to short-circuit + // evaluation + + // When + boolean result = aiInferenceService.isAIAvailable(); + + // Then + Assert.assertFalse(result); + } + + @Test + public void test_constructor_withValidParameters_createsInstance() { + // Given & When + AIInferenceService service = new AIInferenceService(aiConfig, Optional.of(openAIClient)); + + // Then + Assert.assertNotNull(service); + } + + @Test + public void test_constructor_withEmptyClient_createsInstance() { + // Given & When + AIInferenceService service = new AIInferenceService(aiConfig, Optional.empty()); + + // Then + Assert.assertNotNull(service); + } + + private void setupValidAIConfig() { + Mockito.when(aiConfig.isEnabled()).thenReturn(true); + Mockito.when(aiConfig.isAIAvailable()).thenReturn(true); + Mockito.when(aiConfig.getModel()).thenReturn("gpt-3.5-turbo"); + Mockito.when(aiConfig.getTemperature()).thenReturn(0.7); + Mockito.when(aiConfig.getTopP()).thenReturn(0.9); + Mockito.when(aiConfig.loadAdditionalParams()).thenReturn(new HashMap<>()); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoaderTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoaderTest.java new file mode 100644 index 0000000000..e663f3f679 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/ai/PromptTemplateLoaderTest.java @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.ai; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.test.util.ReflectionTestUtils; + +@RunWith(MockitoJUnitRunner.class) +public class PromptTemplateLoaderTest { + + private PromptTemplateLoader promptTemplateLoader; + private static final String MOCK_TEMPLATE = "Sensitive types: {sensitiveTypes}\nCustom prompt: {customPrompt}"; + + @Before + public void setUp() { + promptTemplateLoader = new PromptTemplateLoader(); + // Set mock template to avoid file loading issues in test + ReflectionTestUtils.setField(promptTemplateLoader, "systemTemplate", MOCK_TEMPLATE); + } + + @Test + public void test_buildSystemPrompt_withValidSensitiveTypesAndCustomPrompt_returnsFormattedPrompt() { + // Given + List sensitiveTypes = Arrays.asList("contact_info", "identity_info"); + String customPrompt = "Additional rules for identification"; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain formatted sensitive types", + result.contains("contact_info, identity_info")); + Assert.assertTrue("Should contain custom prompt", + result.contains("Additional rules for identification")); + } + + @Test + public void test_buildSystemPrompt_withEmptySensitiveTypes_returnsDefaultMessage() { + // Given + List sensitiveTypes = Collections.emptyList(); + String customPrompt = "Custom rules"; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain default message for empty types", + result.contains("No specified category.")); + Assert.assertTrue("Should contain custom prompt", result.contains("Custom rules")); + } + + @Test + public void test_buildSystemPrompt_withNullSensitiveTypes_returnsDefaultMessage() { + // Given + List sensitiveTypes = null; + String customPrompt = "Custom rules"; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain default message for null types", + result.contains("No specified category.")); + Assert.assertTrue("Should contain custom prompt", result.contains("Custom rules")); + } + + @Test + public void test_buildSystemPrompt_withEmptyCustomPrompt_returnsDefaultMessage() { + // Given + List sensitiveTypes = Arrays.asList("email", "phone"); + String customPrompt = ""; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain formatted sensitive types", + result.contains("email, phone")); + Assert.assertTrue("Should contain default message for empty prompt", + result.contains("No supplementary rule.")); + } + + @Test + public void test_buildSystemPrompt_withNullCustomPrompt_returnsDefaultMessage() { + // Given + List sensitiveTypes = Arrays.asList("address"); + String customPrompt = null; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain formatted sensitive types", result.contains("address")); + Assert.assertTrue("Should contain default message for null prompt", + result.contains("No supplementary rule.")); + } + + @Test + public void test_buildSystemPrompt_withWhitespaceCustomPrompt_returnsDefaultMessage() { + // Given + List sensitiveTypes = Arrays.asList("name"); + String customPrompt = " \t\n "; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain formatted sensitive types", result.contains("name")); + Assert.assertTrue("Should contain default message for whitespace prompt", + result.contains("No supplementary rule.")); + } + + @Test + public void test_buildSystemPrompt_withSingleSensitiveType_returnsCorrectFormat() { + // Given + List sensitiveTypes = Arrays.asList("credit_card"); + String customPrompt = "Strict validation required"; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain single sensitive type", result.contains("credit_card")); + Assert.assertFalse("Should not contain comma for single type", + result.contains("credit_card,")); + Assert.assertTrue("Should contain custom prompt", + result.contains("Strict validation required")); + } + + @Test(expected = IllegalStateException.class) + public void test_buildSystemPrompt_withNullTemplate_throwsException() { + // Given + ReflectionTestUtils.setField(promptTemplateLoader, "systemTemplate", null); + List sensitiveTypes = Arrays.asList("test"); + String customPrompt = "test"; + + // When + promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then - exception should be thrown + } + + @Test(expected = IllegalStateException.class) + public void test_buildSystemPrompt_withEmptyTemplate_throwsException() { + // Given + ReflectionTestUtils.setField(promptTemplateLoader, "systemTemplate", ""); + List sensitiveTypes = Arrays.asList("test"); + String customPrompt = "test"; + + // When + promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then - exception should be thrown + } + + @Test + public void test_buildSystemPrompt_withMultipleSensitiveTypes_returnsCommaSeparated() { + // Given + List sensitiveTypes = Arrays.asList("email", "phone", "address", "name"); + String customPrompt = "Multiple type validation"; + + // When + String result = promptTemplateLoader.buildSystemPrompt(sensitiveTypes, customPrompt); + + // Then + Assert.assertNotNull("Result should not be null", result); + Assert.assertTrue("Should contain all types comma-separated", + result.contains("email, phone, address, name")); + Assert.assertTrue("Should contain custom prompt", + result.contains("Multiple type validation")); + } + + @Test + public void test_buildSystemPrompt_preservesOriginalTemplate_afterMultipleCalls() { + // Given + List sensitiveTypes1 = Arrays.asList("type1"); + List sensitiveTypes2 = Arrays.asList("type2"); + String customPrompt1 = "prompt1"; + String customPrompt2 = "prompt2"; + + // When + String result1 = promptTemplateLoader.buildSystemPrompt(sensitiveTypes1, customPrompt1); + String result2 = promptTemplateLoader.buildSystemPrompt(sensitiveTypes2, customPrompt2); + + // Then + Assert.assertNotNull("First result should not be null", result1); + Assert.assertNotNull("Second result should not be null", result2); + Assert.assertTrue("First result should contain type1", result1.contains("type1")); + Assert.assertTrue("Second result should contain type2", result2.contains("type2")); + Assert.assertFalse("First result should not contain type2", result1.contains("type2")); + Assert.assertFalse("Second result should not contain type1", result2.contains("type1")); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactoryTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactoryTest.java new file mode 100644 index 0000000000..3a2c3bdc3d --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/factory/ScanningStrategyFactoryTest.java @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.factory; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.ScanningModeType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.odc.service.datasecurity.strategy.AIOnlyStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.JointRecognitionStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.RulesOnlyStrategy; +import com.oceanbase.odc.service.datasecurity.strategy.ScanningStrategy; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +@RunWith(MockitoJUnitRunner.class) +public class ScanningStrategyFactoryTest { + + private ScanningStrategyFactory factory; + private DBTableColumn testColumn; + private List emptyRecognizers; + + @Before + public void setUp() { + factory = new ScanningStrategyFactory(); + testColumn = createTestColumn("user_phone", "varchar", "user phone number"); + emptyRecognizers = Collections.emptyList(); + } + + @Test + public void test_getStrategy_withRulesOnly_returnsRulesOnlyStrategy() { + // When + ScanningStrategy strategy = factory.getStrategy(ScanningModeType.RULES_ONLY); + + // Then + Assert.assertNotNull("Should return a strategy", strategy); + Assert.assertTrue("Should return RulesOnlyStrategy", strategy instanceof RulesOnlyStrategy); + } + + @Test + public void test_getStrategy_withAiOnly_returnsAiOnlyStrategy() { + // When + ScanningStrategy strategy = factory.getStrategy(ScanningModeType.AI_ONLY); + + // Then + Assert.assertNotNull("Should return a strategy", strategy); + Assert.assertTrue("Should return AIOnlyStrategy", strategy instanceof AIOnlyStrategy); + } + + @Test + public void test_getStrategy_withJointRecognition_returnsJointRecognitionStrategy() { + // When + ScanningStrategy strategy = factory.getStrategy(ScanningModeType.JOINT_RECOGNITION); + + // Then + Assert.assertNotNull("Should return a strategy", strategy); + Assert.assertTrue("Should return JointRecognitionStrategy", strategy instanceof JointRecognitionStrategy); + } + + @Test + public void test_getStrategy_withNullMode_returnsNoOpStrategy() { + // When + ScanningStrategy strategy = factory.getStrategy(null); + + // Then + Assert.assertNotNull("Should return a strategy", strategy); + + // Test that it behaves like NoOpStrategy + ScanResult result = strategy.scan(testColumn, emptyRecognizers, emptyRecognizers); + Assert.assertFalse("NoOp strategy should not have basic result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("NoOp strategy should not have AI result", result.getAiRuleResult().isPresent()); + } + + @Test + public void test_getStrategy_returnsSameInstanceForSameMode() { + // When + ScanningStrategy strategy1 = factory.getStrategy(ScanningModeType.RULES_ONLY); + ScanningStrategy strategy2 = factory.getStrategy(ScanningModeType.RULES_ONLY); + + // Then + Assert.assertSame("Should return same instance for same mode", strategy1, strategy2); + } + + @Test + public void test_getStrategy_returnsDifferentInstancesForDifferentModes() { + // When + ScanningStrategy rulesOnlyStrategy = factory.getStrategy(ScanningModeType.RULES_ONLY); + ScanningStrategy aiOnlyStrategy = factory.getStrategy(ScanningModeType.AI_ONLY); + ScanningStrategy jointStrategy = factory.getStrategy(ScanningModeType.JOINT_RECOGNITION); + + // Then + Assert.assertNotSame("Rules and AI strategies should be different", rulesOnlyStrategy, aiOnlyStrategy); + Assert.assertNotSame("Rules and Joint strategies should be different", rulesOnlyStrategy, jointStrategy); + Assert.assertNotSame("AI and Joint strategies should be different", aiOnlyStrategy, jointStrategy); + } + + @Test + public void test_allStrategies_implementScanningStrategy() { + // Given + ScanningModeType[] allModes = + {ScanningModeType.RULES_ONLY, ScanningModeType.AI_ONLY, ScanningModeType.JOINT_RECOGNITION}; + + // When & Then + for (ScanningModeType mode : allModes) { + ScanningStrategy strategy = factory.getStrategy(mode); + Assert.assertNotNull("Strategy should not be null for mode: " + mode, strategy); + Assert.assertTrue("Strategy should implement ScanningStrategy for mode: " + mode, + strategy instanceof ScanningStrategy); + } + } + + @Test + public void test_allStrategies_canHandleBasicOperations() { + // Given + ScanningModeType[] allModes = + {ScanningModeType.RULES_ONLY, ScanningModeType.AI_ONLY, ScanningModeType.JOINT_RECOGNITION}; + List columns = Arrays.asList(testColumn); + + // When & Then + for (ScanningModeType mode : allModes) { + ScanningStrategy strategy = factory.getStrategy(mode); + + // Test single scan + ScanResult singleResult = strategy.scan(testColumn, emptyRecognizers, emptyRecognizers); + Assert.assertNotNull("Single scan result should not be null for mode: " + mode, singleResult); + + // Test batch scan + Map batchResults = strategy.scanBatch(columns, emptyRecognizers, emptyRecognizers); + Assert.assertNotNull("Batch scan results should not be null for mode: " + mode, batchResults); + Assert.assertEquals("Batch scan should return result for each column for mode: " + mode, + 1, batchResults.size()); + } + } + + @Test + public void test_noOpStrategy_handlesBatchScanCorrectly() { + // Given + ScanningStrategy noOpStrategy = factory.getStrategy(null); + DBTableColumn column1 = createTestColumn("col1", "varchar", "comment1"); + DBTableColumn column2 = createTestColumn("col2", "varchar", "comment2"); + List columns = Arrays.asList(column1, column2); + + // When + Map results = noOpStrategy.scanBatch(columns, emptyRecognizers, emptyRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 2, results.size()); + + for (ScanResult result : results.values()) { + Assert.assertFalse("NoOp strategy should not have basic result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("NoOp strategy should not have AI result", result.getAiRuleResult().isPresent()); + } + } + + @Test + public void test_noOpStrategy_handlesNullColumnsGracefully() { + // Given + ScanningStrategy noOpStrategy = factory.getStrategy(null); + DBTableColumn columnWithNulls = new DBTableColumn(); + columnWithNulls.setName(null); + columnWithNulls.setSchemaName(null); + columnWithNulls.setTableName(null); + List columns = Arrays.asList(columnWithNulls); + + // When + Map results = noOpStrategy.scanBatch(columns, emptyRecognizers, emptyRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + Assert.assertTrue("Should contain key for unknown column", + results.containsKey("unknown_schema.unknown_table.unknown_column")); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + return column; + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizerTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizerTest.java new file mode 100644 index 0000000000..b72abe3592 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/AIColumnRecognizerTest.java @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.recognizer; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.core.shared.exception.BadRequestException; +import com.oceanbase.odc.service.common.util.SpringContextUtil; +import com.oceanbase.odc.service.datasecurity.ai.AIInferenceService; +import com.oceanbase.odc.service.datasecurity.ai.PromptTemplateLoader; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; +import com.openai.models.chat.completions.ChatCompletion; + +@RunWith(MockitoJUnitRunner.class) +public class AIColumnRecognizerTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private PromptTemplateLoader promptTemplateLoader; + + @Mock + private AIInferenceService aiInferenceService; + + private AIColumnRecognizer recognizer; + private SensitiveRule aiRule; + + @Before + public void setUp() { + aiRule = createAIRule(); + recognizer = new AIColumnRecognizer(aiRule); + } + + @Test + public void test_recognize_singleColumn_returnsSensitive() { + // Given + DBTableColumn column = createTestColumn("user_phone", "varchar", "user phone number"); + String systemPrompt = "System prompt for AI"; + String aiResponse = + "```json\n[{\"sensitive\": true, \"riskLevel\": \"HIGH\", \"sensitiveCategory\": \"contact_info\"}]\n```"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, aiResponse); + + // When + Optional result = recognizer.recognize(column); + + // Then + Assert.assertTrue(result.isPresent()); + Assert.assertTrue(result.get().isMatched()); + Assert.assertEquals(SensitiveLevel.HIGH, result.get().getLevel()); + Assert.assertEquals("contact_info", result.get().getSensitiveType()); + Assert.assertEquals(SensitiveRuleType.AI, result.get().getSourceRuleType()); + } + } + + @Test + public void test_recognize_singleColumn_returnsNotSensitive() { + // Given + DBTableColumn column = createTestColumn("id", "bigint", "primary key"); + String systemPrompt = "System prompt for AI"; + String aiResponse = "```json\n[{\"sensitive\": false, \"riskLevel\": null, \"sensitiveCategory\": null}]\n```"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, aiResponse); + + // When + Optional result = recognizer.recognize(column); + + // Then + Assert.assertFalse(result.isPresent()); + } + } + + @Test + public void test_recognizeBatch_multipleColumns_returnsMixedResults() { + // Given + List columns = Arrays.asList( + createTestColumn("user_phone", "varchar", "user phone number"), + createTestColumn("id", "bigint", "primary key"), + createTestColumn("email", "varchar", "user email address")); + String systemPrompt = "System prompt for AI"; + String aiResponse = "```json\n[" + + "{\"sensitive\": true, \"riskLevel\": \"HIGH\", \"sensitiveCategory\": \"contact_info\"}," + + "{\"sensitive\": false, \"riskLevel\": null, \"sensitiveCategory\": null}," + + "{\"sensitive\": true, \"riskLevel\": \"MEDIUM\", \"sensitiveCategory\": \"contact_info\"}" + + "]```"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, aiResponse); + + // When + Map> results = recognizer.recognizeBatch(columns); + + // Then + Assert.assertEquals(3, results.size()); + + // Check phone column + String phoneKey = getColumnKey(columns.get(0)); + Assert.assertTrue(results.get(phoneKey).isPresent()); + Assert.assertEquals(SensitiveLevel.HIGH, results.get(phoneKey).get().getLevel()); + + // Check id column + String idKey = getColumnKey(columns.get(1)); + Assert.assertFalse(results.get(idKey).isPresent()); + + // Check email column + String emailKey = getColumnKey(columns.get(2)); + Assert.assertTrue(results.get(emailKey).isPresent()); + Assert.assertEquals(SensitiveLevel.MEDIUM, results.get(emailKey).get().getLevel()); + } + } + + @Test + public void test_recognizeBatch_emptyList_returnsEmptyMap() { + // When + Map> results = recognizer.recognizeBatch(Collections.emptyList()); + + // Then + Assert.assertTrue(results.isEmpty()); + } + + @Test + public void test_recognizeBatch_nullList_returnsEmptyMap() { + // When + Map> results = recognizer.recognizeBatch(null); + + // Then + Assert.assertTrue(results.isEmpty()); + } + + @Test + public void test_recognizeBatch_aiServiceThrowsException_returnsEmptyResults() { + // Given + List columns = Arrays.asList(createTestColumn("test", "varchar", "test")); + String systemPrompt = "System prompt for AI"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + mockedSpringContext.when(() -> SpringContextUtil.getBean(PromptTemplateLoader.class)) + .thenReturn(promptTemplateLoader); + mockedSpringContext.when(() -> SpringContextUtil.getBean(AIInferenceService.class)) + .thenReturn(aiInferenceService); + + Mockito.when(promptTemplateLoader.buildSystemPrompt(Mockito.anyList(), Mockito.anyString())) + .thenReturn(systemPrompt); + Mockito.when(aiInferenceService.chat(Mockito.anyString(), Mockito.anyString())) + .thenThrow(new RuntimeException("AI service error")); + + // When + Map> results = recognizer.recognizeBatch(columns); + + // Then + Assert.assertTrue(results.isEmpty()); + } + } + + @Test + public void test_recognizeBatch_invalidJsonResponse_throwsException() { + // Given + List columns = Arrays.asList(createTestColumn("test", "varchar", "test")); + String systemPrompt = "System prompt for AI"; + String invalidResponse = "This is not a valid JSON response"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, invalidResponse); + + thrown.expect(BadRequestException.class); + thrown.expectMessage("AI response does not contain valid JSON format"); + + // When + recognizer.recognizeBatch(columns); + } + } + + @Test + public void test_recognizeBatch_malformedJsonResponse_throwsException() { + // Given + List columns = Arrays.asList(createTestColumn("test", "varchar", "test")); + String systemPrompt = "System prompt for AI"; + String malformedResponse = + "```json\n[{\"sensitive\": true, \"riskLevel\": \"HIGH\" missing_comma \"field\": \"value\"}]```"; // Invalid + // JSON + // syntax + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, malformedResponse); + + thrown.expect(BadRequestException.class); + thrown.expectMessage("Failed to parse AI response JSON"); + + // When + recognizer.recognizeBatch(columns); + } + } + + @Test + public void test_recognizeBatch_responseWithoutJsonWrapper_parsesCorrectly() { + // Given + DBTableColumn column = createTestColumn("user_phone", "varchar", "user phone number"); + String systemPrompt = "System prompt for AI"; + String aiResponse = "[{\"sensitive\": true, \"riskLevel\": \"HIGH\", \"sensitiveCategory\": \"contact_info\"}]"; // No + // ```json + // wrapper + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, aiResponse); + + // When + Optional result = recognizer.recognize(column); + + // Then + Assert.assertTrue(result.isPresent()); + Assert.assertTrue(result.get().isMatched()); + Assert.assertEquals(SensitiveLevel.HIGH, result.get().getLevel()); + } + } + + @Test + public void test_recognizeBatch_mismatchedResponseCount_handlesGracefully() { + // Given + List columns = Arrays.asList( + createTestColumn("col1", "varchar", "test1"), + createTestColumn("col2", "varchar", "test2"), + createTestColumn("col3", "varchar", "test3")); + String systemPrompt = "System prompt for AI"; + // AI returns only 2 results for 3 columns + String aiResponse = "```json\n[" + + "{\"sensitive\": true, \"riskLevel\": \"HIGH\", \"sensitiveCategory\": \"contact_info\"}," + + "{\"sensitive\": false, \"riskLevel\": null, \"sensitiveCategory\": null}" + + "]```"; + + try (MockedStatic mockedSpringContext = Mockito.mockStatic(SpringContextUtil.class)) { + setupMocks(mockedSpringContext, systemPrompt, aiResponse); + + // When + Map> results = recognizer.recognizeBatch(columns); + + // Then + Assert.assertEquals(2, results.size()); // Only 2 results processed + + String col1Key = getColumnKey(columns.get(0)); + String col2Key = getColumnKey(columns.get(1)); + String col3Key = getColumnKey(columns.get(2)); + + Assert.assertTrue(results.containsKey(col1Key)); + Assert.assertTrue(results.containsKey(col2Key)); + Assert.assertFalse(results.containsKey(col3Key)); // Third column not processed + } + } + + // Helper methods + private void setupMocks(MockedStatic mockedSpringContext, String systemPrompt, + String aiResponse) { + mockedSpringContext.when(() -> SpringContextUtil.getBean(PromptTemplateLoader.class)) + .thenReturn(promptTemplateLoader); + mockedSpringContext.when(() -> SpringContextUtil.getBean(AIInferenceService.class)) + .thenReturn(aiInferenceService); + + Mockito.when(promptTemplateLoader.buildSystemPrompt(Mockito.anyList(), Mockito.anyString())) + .thenReturn(systemPrompt); + + // Mock the chain: completion.choices().get(0).message().content().orElse("[]") + // Use Mockito's deep stubbing with RETURNS_DEEP_STUBS + ChatCompletion mockCompletion = Mockito.mock(ChatCompletion.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(aiInferenceService.chat(Mockito.anyString(), Mockito.anyString())) + .thenReturn(mockCompletion); + Mockito.when(mockCompletion.choices().get(0).message().content()) + .thenReturn(Optional.of(aiResponse)); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + return column; + } + + private SensitiveRule createAIRule() { + SensitiveRule rule = new SensitiveRule(); + rule.setId(1L); + rule.setType(SensitiveRuleType.AI); + rule.setLevel(SensitiveLevel.HIGH); + rule.setEnabled(true); + rule.setAiSensitiveTypes(Arrays.asList("contact_info", "identity_info")); + rule.setAiCustomPrompt("Custom AI prompt for testing"); + return rule; + } + + private String getColumnKey(DBTableColumn column) { + return column.getSchemaName() + "." + column.getTableName() + "." + column.getName(); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizerTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizerTest.java index a936364325..c885dc682e 100644 --- a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizerTest.java +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/GroovyColumnRecognizerTest.java @@ -15,44 +15,58 @@ */ package com.oceanbase.odc.service.datasecurity.recognizer; +import java.util.Optional; + import org.codehaus.groovy.control.MultipleCompilationErrorsException; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; -/** - * @author gaoda.xy - * @date 2023/5/23 19:35 - */ public class GroovyColumnRecognizerTest { + @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void test_recognize_true() { - ColumnRecognizer recognizer = new GroovyColumnRecognizer(buildGroovyScript()); - DBTableColumn dbTableColumn = createDBTableColumn(); - Assert.assertTrue(recognizer.recognize(dbTableColumn)); + SensitiveRule rule = createGroovyRule(1L, buildDefaultGroovyScript()); + ColumnRecognizer recognizer = new GroovyColumnRecognizer(rule); + DBTableColumn dbTableColumn = createTestColumn(); + Optional resultOpt = recognizer.recognize(dbTableColumn); + Assert.assertTrue("Script matching is successful. An Optional with a value should be returned.", + resultOpt.isPresent()); + RecognitionResult result = resultOpt.get(); + Assert.assertEquals("The matching rule ID should be 1.", rule.getId(), result.getMatchedRuleId()); + Assert.assertEquals("The rule type should be GROOVY", SensitiveRuleType.GROOVY, result.getSourceRuleType()); } @Test public void test_recognize_false() { - ColumnRecognizer recognizer = new GroovyColumnRecognizer(buildGroovyScript()); - DBTableColumn dbTableColumn = createDBTableColumn(); - dbTableColumn.setTableName("unmatched"); - Assert.assertFalse(recognizer.recognize(dbTableColumn)); + SensitiveRule rule = createGroovyRule(1L, buildDefaultGroovyScript()); + ColumnRecognizer recognizer = new GroovyColumnRecognizer(rule); + DBTableColumn dbTableColumn = createTestColumn(); + dbTableColumn.setTableName("unmatched_table"); + Optional resultOpt = recognizer.recognize(dbTableColumn); + Assert.assertFalse("Script matching failed. It should return an empty Optional.", resultOpt.isPresent()); } @Test public void test_recognize_nullColumnName() { - ColumnRecognizer recognizer = new GroovyColumnRecognizer(buildGroovyScript()); - DBTableColumn dbTableColumn = createDBTableColumn(); - dbTableColumn.setTableName(null); - Assert.assertFalse(recognizer.recognize(dbTableColumn)); + SensitiveRule rule = createGroovyRule(1L, buildDefaultGroovyScript()); + ColumnRecognizer recognizer = new GroovyColumnRecognizer(rule); + DBTableColumn dbTableColumn = createTestColumn(); + dbTableColumn.setName(null); + Optional resultOpt = recognizer.recognize(dbTableColumn); + Assert.assertFalse("The script execution has failed. It should return an empty Optional.", + resultOpt.isPresent()); } @Test @@ -60,7 +74,7 @@ public void test_securityInterceptor_systemExit() { thrown.expect(Exception.class); thrown.expectMessage("Method call is not security"); String script = "System.exit(-1);"; - new GroovyColumnRecognizer(script); + new GroovyColumnRecognizer(createGroovyRule(1L, script)); } @Test @@ -70,7 +84,7 @@ public void test_securityInterceptor_forLoop() { String script = "for (int i = 0; i < 1; i++) {\n" + " i = 0;\n" + "}"; - new GroovyColumnRecognizer(script); + new GroovyColumnRecognizer(createGroovyRule(1L, script)); } @Test @@ -80,28 +94,30 @@ public void test_securityInterceptor_whileLoop() { String script = "while(true) {\n" + " int i = 0;\n" + "}"; - new GroovyColumnRecognizer(script); + new GroovyColumnRecognizer(createGroovyRule(1L, script)); } @Test - public void test_securityInterceptor_threadSleep() throws InterruptedException { + public void test_securityInterceptor_threadSleep() { thrown.expect(MultipleCompilationErrorsException.class); thrown.expectMessage("java.lang.Thread"); String script = "Thread.sleep(1000);"; - new GroovyColumnRecognizer(script); + new GroovyColumnRecognizer(createGroovyRule(1L, script)); } @Test - public void test_securityInterceptor_importPackage() throws InterruptedException { + public void test_securityInterceptor_importPackage() { thrown.expect(MultipleCompilationErrorsException.class); thrown.expectMessage("java.lang.System"); String script = "import java.lang.System;"; - new GroovyColumnRecognizer(script); + new GroovyColumnRecognizer(createGroovyRule(1L, script)); } - private String buildGroovyScript() { + // --- 辅助方法 --- + + private String buildDefaultGroovyScript() { return "if (column.name.equals(\"column\")) {\n" - + " if (column.table.equalsIgnoreCase(\"IAM_USER\")) {\n" + + " if (column.table.equalsIgnoreCase(\"iam_user\")) {\n" + " if (column.schema.length() > 0) {\n" + " if (column.comment.indexOf(\"user\") > 0) {\n" + " if (column.type.toLowerCase().equals(\"varchar\")) {\n" @@ -114,7 +130,7 @@ private String buildGroovyScript() { + "return false;"; } - private DBTableColumn createDBTableColumn() { + private DBTableColumn createTestColumn() { DBTableColumn dbTableColumn = new DBTableColumn(); dbTableColumn.setSchemaName("odc_meta"); dbTableColumn.setTableName("iam_user"); @@ -124,4 +140,13 @@ private DBTableColumn createDBTableColumn() { return dbTableColumn; } + private SensitiveRule createGroovyRule(Long id, String script) { + SensitiveRule rule = new SensitiveRule(); + rule.setId(id); + rule.setType(SensitiveRuleType.GROOVY); + rule.setGroovyScript(script); + rule.setLevel(SensitiveLevel.HIGH); + rule.setEnabled(true); + return rule; + } } diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizerTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizerTest.java index 8c3f19aa2a..da12ab3a42 100644 --- a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizerTest.java +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/PathColumnRecognizerTest.java @@ -16,34 +16,52 @@ package com.oceanbase.odc.service.datasecurity.recognizer; import java.util.Arrays; +import java.util.List; +import java.util.Optional; import org.junit.Assert; import org.junit.Test; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; -/** - * @author gaoda.xy - * @date 2023/5/24 15:05 - */ public class PathColumnRecognizerTest { @Test public void test_recognize_true() { - ColumnRecognizer recognizer = new PathColumnRecognizer(Arrays.asList("*.*b*.c"), Arrays.asList("a.b.*")); - Assert.assertTrue(recognizer.recognize(createDBTableColumn("a", "b12", "c"))); - Assert.assertTrue(recognizer.recognize(createDBTableColumn("a12", "34b56", "c"))); - Assert.assertTrue(recognizer.recognize(createDBTableColumn("a12", "34b", "c"))); + SensitiveRule rule = createPathRule(1L, Arrays.asList("*.*b*.c"), Arrays.asList("a.b.*")); + ColumnRecognizer recognizer = new PathColumnRecognizer(rule); + + Optional result1 = recognizer.recognize(createDBTableColumn("a", "b12", "c")); + Assert.assertTrue("Path 'a.b12.c' should match successfully", result1.isPresent()); + Assert.assertEquals(rule.getId(), result1.get().getMatchedRuleId()); + + Optional result2 = recognizer.recognize(createDBTableColumn("a12", "34b56", "c")); + Assert.assertTrue("Path 'a12.34b56.c' should match successfully", result2.isPresent()); + + Optional result3 = recognizer.recognize(createDBTableColumn("a12", "34b", "c")); + Assert.assertTrue("Path 'a12.34b.c' should match successfully", result3.isPresent()); } @Test public void test_recognize_false() { - ColumnRecognizer recognizer = new PathColumnRecognizer(Arrays.asList("*.*b*.c"), Arrays.asList("a.b.*")); - Assert.assertFalse(recognizer.recognize(createDBTableColumn("a", "b", "c"))); - Assert.assertFalse(recognizer.recognize(createDBTableColumn("a12", "b34", "c56"))); - Assert.assertFalse(recognizer.recognize(createDBTableColumn("a12", "b34", null))); + SensitiveRule rule = createPathRule(1L, Arrays.asList("*.*b*.c"), Arrays.asList("a.b.*")); + ColumnRecognizer recognizer = new PathColumnRecognizer(rule); + + Assert.assertFalse("The path 'a.b.c' should be excluded as the match failed.", + recognizer.recognize(createDBTableColumn("a", "b", "c")).isPresent()); + + Assert.assertFalse("Path 'a12.b34.c56' should not match; the match failed.", + recognizer.recognize(createDBTableColumn("a12", "b34", "c56")).isPresent()); + + Assert.assertFalse("Path 'a12.b34.null' should not match; the match failed.", + recognizer.recognize(createDBTableColumn("a12", "b34", null)).isPresent()); } + private DBTableColumn createDBTableColumn(String schemaName, String tableName, String columnName) { DBTableColumn column = new DBTableColumn(); column.setSchemaName(schemaName); @@ -51,4 +69,15 @@ private DBTableColumn createDBTableColumn(String schemaName, String tableName, S column.setName(columnName); return column; } + + private SensitiveRule createPathRule(Long id, List includes, List excludes) { + SensitiveRule rule = new SensitiveRule(); + rule.setId(id); + rule.setType(SensitiveRuleType.PATH); + rule.setPathIncludes(includes); + rule.setPathExcludes(excludes); + rule.setLevel(SensitiveLevel.HIGH); + rule.setEnabled(true); + return rule; + } } diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizerTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizerTest.java index 524b2acd91..93f312d1a2 100644 --- a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizerTest.java +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/recognizer/RegexColumnRecognizerTest.java @@ -15,33 +15,58 @@ */ package com.oceanbase.odc.service.datasecurity.recognizer; +import java.util.Optional; + import org.junit.Assert; import org.junit.Test; +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRule; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; import com.oceanbase.tools.dbbrowser.model.DBTableColumn; -/** - * @author gaoda.xy - * @date 2023/5/24 16:21 - */ public class RegexColumnRecognizerTest { @Test public void recognize_returnTrue() { - ColumnRecognizer recognizer = createRegexColumnRecognizer(); - Assert.assertTrue(recognizer.recognize(createDBTableColumn("xxx", "xxx", "user_email", "email of user"))); + SensitiveRule rule = createTestRegexRule(1L); + ColumnRecognizer recognizer = new RegexColumnRecognizer(rule); + DBTableColumn column = createDBTableColumn("xxx", "xxx", "user_email", "email of user"); + + Optional resultOpt = recognizer.recognize(column); + + Assert.assertTrue("The regular expression should match successfully.", resultOpt.isPresent()); + Assert.assertEquals("The matching rule ID should be 1.", rule.getId(), resultOpt.get().getMatchedRuleId()); } @Test public void recognize_returnFalse() { - ColumnRecognizer recognizer = createRegexColumnRecognizer(); - Assert.assertFalse(recognizer.recognize(createDBTableColumn("xxx", "xxx", "user_email", null))); - Assert.assertFalse(recognizer.recognize(createDBTableColumn(" ", "xxx", "user_email", null))); - Assert.assertFalse(recognizer.recognize(createDBTableColumn("xxx", "xxx", "user", "email"))); + SensitiveRule rule = createTestRegexRule(1L); + ColumnRecognizer recognizer = new RegexColumnRecognizer(rule); + + Assert.assertFalse("When Comment is null, it should not match.", + recognizer.recognize(createDBTableColumn("xxx", "xxx", "user_email", null)).isPresent()); + + Assert.assertFalse("It should fail when the SchemaName does not match.", + recognizer.recognize(createDBTableColumn(" ", "xxx", "user_email", "email of user")).isPresent()); + + Assert.assertFalse("When both ColumnName and Comment do not match, it should fail.", + recognizer.recognize(createDBTableColumn("xxx", "xxx", "user", "some info")).isPresent()); } - private RegexColumnRecognizer createRegexColumnRecognizer() { - return new RegexColumnRecognizer("^\\S+$", "^\\S+$", "^\\S*email\\S*$", "^[\\S\\s]*email[\\S\\s]*$"); + + private SensitiveRule createTestRegexRule(Long id) { + SensitiveRule rule = new SensitiveRule(); + rule.setId(id); + rule.setType(SensitiveRuleType.REGEX); + rule.setEnabled(true); + rule.setLevel(SensitiveLevel.HIGH); + rule.setDatabaseRegexExpression("^\\S+$"); + rule.setTableRegexExpression("^\\S+$"); + rule.setColumnRegexExpression("^\\S*email\\S*$"); + rule.setColumnCommentRegexExpression("^[\\S\\s]*email[\\S\\s]*$"); + return rule; } private DBTableColumn createDBTableColumn(String schemaName, String tableName, String columnName, String comment) { diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategyTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategyTest.java new file mode 100644 index 0000000000..1f2d94ebad --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AIOnlyStrategyTest.java @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +@RunWith(MockitoJUnitRunner.class) +public class AIOnlyStrategyTest { + + @Mock + private ColumnRecognizer mockBasicRecognizer; + + @Mock + private ColumnRecognizer mockAiRecognizer; + + private AIOnlyStrategy strategy; + private List basicRecognizers; + private List aiRecognizers; + private DBTableColumn testColumn; + + @Before + public void setUp() { + strategy = new AIOnlyStrategy(); + basicRecognizers = Arrays.asList(mockBasicRecognizer); + aiRecognizers = Arrays.asList(mockAiRecognizer); + testColumn = createTestColumn("user_phone", "varchar", "user phone number"); + } + + @Test + public void test_scan_withAiRecognizerMatch_returnsAiResultOnly() { + // Given + RecognitionResult aiResult = createRecognitionResult(2L, SensitiveLevel.MEDIUM, SensitiveRuleType.AI); + + Mockito.when(mockAiRecognizer.recognize(testColumn)).thenReturn(Optional.of(aiResult)); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertTrue("Should have AI rule result", result.getAiRuleResult().isPresent()); + Assert.assertEquals("Should return AI result", aiResult, result.getAiRuleResult().get()); + + // Verify basic recognizer is not called (AI only strategy) + Mockito.verify(mockBasicRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scan_withNoAiRecognizerMatch_returnsEmptyResult() { + // Given + Mockito.when(mockAiRecognizer.recognize(testColumn)).thenReturn(Optional.empty()); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + + // Verify basic recognizer is not called + Mockito.verify(mockBasicRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scan_withEmptyAiRecognizers_returnsEmptyResult() { + // Given + List emptyAiRecognizers = Collections.emptyList(); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, emptyAiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + } + + @Test + public void test_scanBatch_withMixedResults_returnsCorrectMapping() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + DBTableColumn column3 = createTestColumn("user_name", "varchar", "name"); + List columns = Arrays.asList(column1, column2, column3); + + RecognitionResult aiResult1 = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.AI); + RecognitionResult aiResult3 = createRecognitionResult(3L, SensitiveLevel.LOW, SensitiveRuleType.AI); + + // Mock recognizeBatch for single recognizer scenario + Map> batchResults = new HashMap<>(); + batchResults.put(getColumnKey(column1), Optional.of(aiResult1)); + batchResults.put(getColumnKey(column2), Optional.empty()); + batchResults.put(getColumnKey(column3), Optional.of(aiResult3)); + Mockito.when(mockAiRecognizer.recognizeBatch(columns)).thenReturn(batchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 3, results.size()); + + String key1 = getColumnKey(column1); + String key2 = getColumnKey(column2); + String key3 = getColumnKey(column3); + + Assert.assertFalse("Column1 should not have basic result", results.get(key1).getBasicRuleResult().isPresent()); + Assert.assertTrue("Column1 should have AI result", results.get(key1).getAiRuleResult().isPresent()); + + Assert.assertFalse("Column2 should not have basic result", results.get(key2).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column2 should not have AI result", results.get(key2).getAiRuleResult().isPresent()); + + Assert.assertFalse("Column3 should not have basic result", results.get(key3).getBasicRuleResult().isPresent()); + Assert.assertTrue("Column3 should have AI result", results.get(key3).getAiRuleResult().isPresent()); + + // Verify basic recognizer is never called + Mockito.verify(mockBasicRecognizer, Mockito.never()).recognize(Mockito.any()); + Mockito.verify(mockBasicRecognizer, Mockito.never()).recognizeBatch(Mockito.any()); + } + + @Test + public void test_scanBatch_withEmptyColumns_returnsEmptyMap() { + // Given + List emptyColumns = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(emptyColumns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertTrue("Should return empty map", results.isEmpty()); + } + + @Test + public void test_scanBatch_withEmptyAiRecognizers_returnsEmptyResults() { + // Given + List columns = Arrays.asList(testColumn); + List emptyAiRecognizers = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, emptyAiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + String key = getColumnKey(testColumn); + Assert.assertFalse("Should not have basic result", results.get(key).getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI result", results.get(key).getAiRuleResult().isPresent()); + } + + @Test + public void test_scanBatch_withSingleAiRecognizer_usesBatchRecognition() { + // Given + List columns = Arrays.asList(testColumn); + Map> batchResults = Collections.singletonMap( + getColumnKey(testColumn), + Optional.of(createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.AI))); + + Mockito.when(mockAiRecognizer.recognizeBatch(columns)).thenReturn(batchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + String key = getColumnKey(testColumn); + Assert.assertTrue("Should have AI result", results.get(key).getAiRuleResult().isPresent()); + + // Verify batch recognition is used + Mockito.verify(mockAiRecognizer).recognizeBatch(columns); + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + return column; + } + + private RecognitionResult createRecognitionResult(Long ruleId, SensitiveLevel level, SensitiveRuleType ruleType) { + return RecognitionResult.builder() + .matched(true) + .matchedRuleId(ruleId) + .level(level) + .sourceRuleType(ruleType) + .build(); + } + + private String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategyTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategyTest.java new file mode 100644 index 0000000000..2e1317778c --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/AbstractScanningStrategyTest.java @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +@RunWith(MockitoJUnitRunner.class) +public class AbstractScanningStrategyTest { + + @Mock + private ColumnRecognizer mockRecognizer1; + + @Mock + private ColumnRecognizer mockRecognizer2; + + private TestableAbstractScanningStrategy strategy; + private DBTableColumn testColumn; + + @Before + public void setUp() { + strategy = new TestableAbstractScanningStrategy(); + testColumn = createTestColumn("user_phone", "varchar", "user phone number"); + } + + @Test + public void test_findFirstMatch_withMatchingRecognizer_returnsResult() { + // Given + RecognitionResult expectedResult = createRecognitionResult(1L, SensitiveLevel.HIGH); + List recognizers = Arrays.asList(mockRecognizer1, mockRecognizer2); + + Mockito.when(mockRecognizer1.recognize(testColumn)).thenReturn(Optional.empty()); + Mockito.when(mockRecognizer2.recognize(testColumn)).thenReturn(Optional.of(expectedResult)); + + // When + Optional result = strategy.findFirstMatch(recognizers, testColumn); + + // Then + Assert.assertTrue("Should find matching result", result.isPresent()); + Assert.assertEquals("Should return expected result", expectedResult, result.get()); + } + + @Test + public void test_findFirstMatch_withNoMatchingRecognizer_returnsEmpty() { + // Given + List recognizers = Arrays.asList(mockRecognizer1, mockRecognizer2); + + Mockito.when(mockRecognizer1.recognize(testColumn)).thenReturn(Optional.empty()); + Mockito.when(mockRecognizer2.recognize(testColumn)).thenReturn(Optional.empty()); + + // When + Optional result = strategy.findFirstMatch(recognizers, testColumn); + + // Then + Assert.assertFalse("Should not find any result", result.isPresent()); + } + + @Test + public void test_findFirstMatch_withEmptyRecognizers_returnsEmpty() { + // Given + List recognizers = Collections.emptyList(); + + // When + Optional result = strategy.findFirstMatch(recognizers, testColumn); + + // Then + Assert.assertFalse("Should not find any result with empty recognizers", result.isPresent()); + } + + @Test + public void test_findAllFirstMatches_withMultipleColumns_returnsCorrectMapping() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + List columns = Arrays.asList(column1, column2); + List recognizers = Arrays.asList(mockRecognizer1); + + RecognitionResult result1 = createRecognitionResult(1L, SensitiveLevel.HIGH); + + // Mock recognizeBatch method for single recognizer scenario + Map> batchResults = new HashMap<>(); + batchResults.put(strategy.getColumnKey(column1), Optional.of(result1)); + batchResults.put(strategy.getColumnKey(column2), Optional.empty()); + Mockito.when(mockRecognizer1.recognizeBatch(columns)).thenReturn(batchResults); + + // When + Map> results = strategy.findAllFirstMatches(recognizers, columns); + + // Then + Assert.assertEquals("Should return results for all columns", 2, results.size()); + Assert.assertTrue("Should find result for column1", results.get(strategy.getColumnKey(column1)).isPresent()); + Assert.assertFalse("Should not find result for column2", + results.get(strategy.getColumnKey(column2)).isPresent()); + } + + @Test + public void test_findAllFirstMatches_withEmptyColumns_returnsEmptyMap() { + // Given + List columns = Collections.emptyList(); + List recognizers = Arrays.asList(mockRecognizer1); + + // When + Map> results = strategy.findAllFirstMatches(recognizers, columns); + + // Then + Assert.assertTrue("Should return empty map", results.isEmpty()); + } + + @Test + public void test_findAllFirstMatches_withEmptyRecognizers_returnsEmptyResults() { + // Given + List columns = Arrays.asList(testColumn); + List recognizers = Collections.emptyList(); + + // When + Map> results = strategy.findAllFirstMatches(recognizers, columns); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + Assert.assertFalse("Should not find any result", results.get(strategy.getColumnKey(testColumn)).isPresent()); + } + + @Test + public void test_findAllFirstMatches_withMultipleRecognizers_usesIndividualRecognize() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + List columns = Arrays.asList(column1, column2); + List recognizers = Arrays.asList(mockRecognizer1, mockRecognizer2); + + RecognitionResult result1 = createRecognitionResult(1L, SensitiveLevel.HIGH); + + // Mock individual recognize calls for multiple recognizers scenario + Mockito.when(mockRecognizer1.recognize(column1)).thenReturn(Optional.empty()); + Mockito.when(mockRecognizer2.recognize(column1)).thenReturn(Optional.of(result1)); + Mockito.when(mockRecognizer1.recognize(column2)).thenReturn(Optional.empty()); + Mockito.when(mockRecognizer2.recognize(column2)).thenReturn(Optional.empty()); + + // When + Map> results = strategy.findAllFirstMatches(recognizers, columns); + + // Then + Assert.assertEquals("Should return results for all columns", 2, results.size()); + Assert.assertTrue("Should find result for column1", results.get(strategy.getColumnKey(column1)).isPresent()); + Assert.assertFalse("Should not find result for column2", + results.get(strategy.getColumnKey(column2)).isPresent()); + + // Verify that recognizeBatch was not called for multiple recognizers + Mockito.verify(mockRecognizer1, Mockito.never()).recognizeBatch(Mockito.any()); + Mockito.verify(mockRecognizer2, Mockito.never()).recognizeBatch(Mockito.any()); + } + + @Test + public void test_getColumnKey_returnsCorrectFormat() { + // Given + DBTableColumn column = createTestColumn("user_phone", "varchar", "phone"); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + + // When + String key = strategy.getColumnKey(column); + + // Then + Assert.assertEquals("Should return correct column key format", "test_schema.test_table.user_phone", key); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + return column; + } + + private RecognitionResult createRecognitionResult(Long ruleId, SensitiveLevel level) { + return RecognitionResult.builder() + .matched(true) + .matchedRuleId(ruleId) + .level(level) + .sourceRuleType(SensitiveRuleType.REGEX) + .build(); + } + + // Testable implementation of AbstractScanningStrategy for testing protected methods + private static class TestableAbstractScanningStrategy extends AbstractScanningStrategy { + @Override + public com.oceanbase.odc.service.datasecurity.model.ScanResult scan( + DBTableColumn column, + List basicRecognizers, + List aiRecognizers) { + return null; // Not used in these tests + } + + @Override + public Map scanBatch( + List columns, + List basicRecognizers, + List aiRecognizers) { + return null; // Not used in these tests + } + + // Expose protected methods for testing + @Override + public Optional findFirstMatch(List recognizers, DBTableColumn column) { + return super.findFirstMatch(recognizers, column); + } + + @Override + public Map> findAllFirstMatches(List recognizers, + List columns) { + return super.findAllFirstMatches(recognizers, columns); + } + + @Override + public String getColumnKey(DBTableColumn column) { + return super.getColumnKey(column); + } + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategyTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategyTest.java new file mode 100644 index 0000000000..bb94030861 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/JointRecognitionStrategyTest.java @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +@RunWith(MockitoJUnitRunner.class) +public class JointRecognitionStrategyTest { + + @Mock + private ColumnRecognizer mockBasicRecognizer; + + @Mock + private ColumnRecognizer mockAiRecognizer; + + private JointRecognitionStrategy strategy; + private List basicRecognizers; + private List aiRecognizers; + private DBTableColumn testColumn; + + @Before + public void setUp() { + strategy = new JointRecognitionStrategy(); + basicRecognizers = Arrays.asList(mockBasicRecognizer); + aiRecognizers = Arrays.asList(mockAiRecognizer); + testColumn = createTestColumn("user_phone", "varchar", "user phone number"); + } + + @Test + public void test_scan_withBasicRecognizerMatch_returnsBasicResultAndSkipsAi() { + // Given + RecognitionResult basicResult = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.REGEX); + + Mockito.when(mockBasicRecognizer.recognize(testColumn)).thenReturn(Optional.of(basicResult)); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertTrue("Should have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + Assert.assertEquals("Should return basic result", basicResult, result.getBasicRuleResult().get()); + + // Verify AI recognizer is not called when basic recognizer matches + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scan_withNoBasicMatchButAiMatch_returnsAiResult() { + // Given + RecognitionResult aiResult = createRecognitionResult(2L, SensitiveLevel.MEDIUM, SensitiveRuleType.AI); + + Mockito.when(mockBasicRecognizer.recognize(testColumn)).thenReturn(Optional.empty()); + Mockito.when(mockAiRecognizer.recognize(testColumn)).thenReturn(Optional.of(aiResult)); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertTrue("Should have AI rule result", result.getAiRuleResult().isPresent()); + Assert.assertEquals("Should return AI result", aiResult, result.getAiRuleResult().get()); + + // Verify AI recognizer is called as fallback + Mockito.verify(mockAiRecognizer).recognize(testColumn); + } + + @Test + public void test_scan_withNoMatches_returnsEmptyResult() { + // Given + Mockito.when(mockBasicRecognizer.recognize(testColumn)).thenReturn(Optional.empty()); + Mockito.when(mockAiRecognizer.recognize(testColumn)).thenReturn(Optional.empty()); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + + // Verify both recognizers are called + Mockito.verify(mockBasicRecognizer).recognize(testColumn); + Mockito.verify(mockAiRecognizer).recognize(testColumn); + } + + @Test + public void test_scanBatch_withMixedResults_optimizesAiCalls() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + DBTableColumn column3 = createTestColumn("user_name", "varchar", "name"); + DBTableColumn column4 = createTestColumn("user_id", "bigint", "id"); + List columns = Arrays.asList(column1, column2, column3, column4); + + // Basic recognizer matches column1 and column3 + RecognitionResult basicResult1 = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.REGEX); + RecognitionResult basicResult3 = createRecognitionResult(3L, SensitiveLevel.LOW, SensitiveRuleType.GROOVY); + + // Mock recognizeBatch for basic recognizer (single recognizer scenario) + Map> basicBatchResults = new HashMap<>(); + basicBatchResults.put(getColumnKey(column1), Optional.of(basicResult1)); + basicBatchResults.put(getColumnKey(column2), Optional.empty()); + basicBatchResults.put(getColumnKey(column3), Optional.of(basicResult3)); + basicBatchResults.put(getColumnKey(column4), Optional.empty()); + Mockito.when(mockBasicRecognizer.recognizeBatch(columns)).thenReturn(basicBatchResults); + + // AI recognizer only matches column2 (column4 has no match) + RecognitionResult aiResult2 = createRecognitionResult(2L, SensitiveLevel.MEDIUM, SensitiveRuleType.AI); + + // Mock recognizeBatch for AI recognizer on remaining columns (column2, column4) + List remainingColumns = Arrays.asList(column2, column4); + Map> aiBatchResults = new HashMap<>(); + aiBatchResults.put(getColumnKey(column2), Optional.of(aiResult2)); + aiBatchResults.put(getColumnKey(column4), Optional.empty()); + Mockito.when(mockAiRecognizer.recognizeBatch(remainingColumns)).thenReturn(aiBatchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 4, results.size()); + + String key1 = getColumnKey(column1); + String key2 = getColumnKey(column2); + String key3 = getColumnKey(column3); + String key4 = getColumnKey(column4); + + // Column1: basic match, no AI call + Assert.assertTrue("Column1 should have basic result", results.get(key1).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column1 should not have AI result", results.get(key1).getAiRuleResult().isPresent()); + + // Column2: no basic match, AI match + Assert.assertFalse("Column2 should not have basic result", results.get(key2).getBasicRuleResult().isPresent()); + Assert.assertTrue("Column2 should have AI result", results.get(key2).getAiRuleResult().isPresent()); + + // Column3: basic match, no AI call + Assert.assertTrue("Column3 should have basic result", results.get(key3).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column3 should not have AI result", results.get(key3).getAiRuleResult().isPresent()); + + // Column4: no matches + Assert.assertFalse("Column4 should not have basic result", results.get(key4).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column4 should not have AI result", results.get(key4).getAiRuleResult().isPresent()); + + // Verify AI recognizer is only called for columns without basic matches + Mockito.verify(mockAiRecognizer).recognizeBatch(remainingColumns); + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scanBatch_withAllBasicMatches_skipsAiCompletely() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + List columns = Arrays.asList(column1, column2); + + RecognitionResult basicResult1 = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.REGEX); + RecognitionResult basicResult2 = createRecognitionResult(2L, SensitiveLevel.MEDIUM, SensitiveRuleType.GROOVY); + + // Mock recognizeBatch for basic recognizer (all matches) + Map> basicBatchResults = new HashMap<>(); + basicBatchResults.put(getColumnKey(column1), Optional.of(basicResult1)); + basicBatchResults.put(getColumnKey(column2), Optional.of(basicResult2)); + Mockito.when(mockBasicRecognizer.recognizeBatch(columns)).thenReturn(basicBatchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 2, results.size()); + + // Verify AI recognizer is never called + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + Mockito.verify(mockAiRecognizer, Mockito.never()).recognizeBatch(Mockito.any()); + } + + @Test + public void test_scanBatch_withNoBasicMatches_callsAiForAll() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + List columns = Arrays.asList(column1, column2); + + // Mock recognizeBatch for basic recognizer (no matches) + Map> basicBatchResults = new HashMap<>(); + basicBatchResults.put(getColumnKey(column1), Optional.empty()); + basicBatchResults.put(getColumnKey(column2), Optional.empty()); + Mockito.when(mockBasicRecognizer.recognizeBatch(columns)).thenReturn(basicBatchResults); + + RecognitionResult aiResult1 = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.AI); + + // Mock recognizeBatch for AI recognizer (all columns since no basic matches) + Map> aiBatchResults = new HashMap<>(); + aiBatchResults.put(getColumnKey(column1), Optional.of(aiResult1)); + aiBatchResults.put(getColumnKey(column2), Optional.empty()); + Mockito.when(mockAiRecognizer.recognizeBatch(columns)).thenReturn(aiBatchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 2, results.size()); + + // Verify AI recognizer is called for all columns + Mockito.verify(mockAiRecognizer).recognizeBatch(columns); + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scanBatch_withEmptyColumns_returnsEmptyMap() { + // Given + List emptyColumns = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(emptyColumns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertTrue("Should return empty map", results.isEmpty()); + } + + @Test + public void test_scanBatch_withEmptyRecognizers_returnsEmptyResults() { + // Given + List columns = Arrays.asList(testColumn); + List emptyBasicRecognizers = Collections.emptyList(); + List emptyAiRecognizers = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(columns, emptyBasicRecognizers, emptyAiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + String key = getColumnKey(testColumn); + Assert.assertFalse("Should not have basic result", results.get(key).getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI result", results.get(key).getAiRuleResult().isPresent()); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + return column; + } + + private RecognitionResult createRecognitionResult(Long ruleId, SensitiveLevel level, SensitiveRuleType ruleType) { + return RecognitionResult.builder() + .matched(true) + .matchedRuleId(ruleId) + .level(level) + .sourceRuleType(ruleType) + .build(); + } + + private String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + } +} diff --git a/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategyTest.java b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategyTest.java new file mode 100644 index 0000000000..29940e30a5 --- /dev/null +++ b/server/odc-service/src/test/java/com/oceanbase/odc/service/datasecurity/strategy/RulesOnlyStrategyTest.java @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.service.datasecurity.strategy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import com.oceanbase.odc.service.datasecurity.model.RecognitionResult; +import com.oceanbase.odc.service.datasecurity.model.ScanResult; +import com.oceanbase.odc.service.datasecurity.model.SensitiveLevel; +import com.oceanbase.odc.service.datasecurity.model.SensitiveRuleType; +import com.oceanbase.odc.service.datasecurity.recognizer.ColumnRecognizer; +import com.oceanbase.tools.dbbrowser.model.DBTableColumn; + +@RunWith(MockitoJUnitRunner.class) +public class RulesOnlyStrategyTest { + + @Mock + private ColumnRecognizer mockBasicRecognizer; + + @Mock + private ColumnRecognizer mockAiRecognizer; + + private RulesOnlyStrategy strategy; + private List basicRecognizers; + private List aiRecognizers; + private DBTableColumn testColumn; + + @Before + public void setUp() { + strategy = new RulesOnlyStrategy(); + basicRecognizers = Arrays.asList(mockBasicRecognizer); + aiRecognizers = Arrays.asList(mockAiRecognizer); + testColumn = createTestColumn("user_phone", "varchar", "user phone number"); + } + + @Test + public void test_scan_withBasicRecognizerMatch_returnsBasicResultOnly() { + // Given + RecognitionResult basicResult = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.REGEX); + + Mockito.when(mockBasicRecognizer.recognize(testColumn)).thenReturn(Optional.of(basicResult)); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertTrue("Should have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + Assert.assertEquals("Should return basic result", basicResult, result.getBasicRuleResult().get()); + + // Verify AI recognizer is not called (rules only strategy) + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scan_withNoBasicRecognizerMatch_returnsEmptyResult() { + // Given + Mockito.when(mockBasicRecognizer.recognize(testColumn)).thenReturn(Optional.empty()); + + // When + ScanResult result = strategy.scan(testColumn, basicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + + // Verify AI recognizer is not called + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + } + + @Test + public void test_scan_withEmptyBasicRecognizers_returnsEmptyResult() { + // Given + List emptyBasicRecognizers = Collections.emptyList(); + + // When + ScanResult result = strategy.scan(testColumn, emptyBasicRecognizers, aiRecognizers); + + // Then + Assert.assertFalse("Should not have basic rule result", result.getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI rule result", result.getAiRuleResult().isPresent()); + } + + @Test + public void test_scanBatch_withMixedResults_returnsCorrectMapping() { + // Given + DBTableColumn column1 = createTestColumn("user_phone", "varchar", "phone"); + DBTableColumn column2 = createTestColumn("user_email", "varchar", "email"); + DBTableColumn column3 = createTestColumn("user_name", "varchar", "name"); + List columns = Arrays.asList(column1, column2, column3); + + RecognitionResult result1 = createRecognitionResult(1L, SensitiveLevel.HIGH, SensitiveRuleType.REGEX); + RecognitionResult result3 = createRecognitionResult(3L, SensitiveLevel.LOW, SensitiveRuleType.GROOVY); + + // Mock recognizeBatch for single recognizer scenario + Map> batchResults = new HashMap<>(); + batchResults.put(getColumnKey(column1), Optional.of(result1)); + batchResults.put(getColumnKey(column2), Optional.empty()); + batchResults.put(getColumnKey(column3), Optional.of(result3)); + Mockito.when(mockBasicRecognizer.recognizeBatch(columns)).thenReturn(batchResults); + + // When + Map results = strategy.scanBatch(columns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 3, results.size()); + + String key1 = getColumnKey(column1); + String key2 = getColumnKey(column2); + String key3 = getColumnKey(column3); + + Assert.assertTrue("Column1 should have basic result", results.get(key1).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column1 should not have AI result", results.get(key1).getAiRuleResult().isPresent()); + + Assert.assertFalse("Column2 should not have basic result", results.get(key2).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column2 should not have AI result", results.get(key2).getAiRuleResult().isPresent()); + + Assert.assertTrue("Column3 should have basic result", results.get(key3).getBasicRuleResult().isPresent()); + Assert.assertFalse("Column3 should not have AI result", results.get(key3).getAiRuleResult().isPresent()); + + // Verify AI recognizer is never called + Mockito.verify(mockAiRecognizer, Mockito.never()).recognize(Mockito.any()); + Mockito.verify(mockAiRecognizer, Mockito.never()).recognizeBatch(Mockito.any()); + } + + @Test + public void test_scanBatch_withEmptyColumns_returnsEmptyMap() { + // Given + List emptyColumns = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(emptyColumns, basicRecognizers, aiRecognizers); + + // Then + Assert.assertTrue("Should return empty map", results.isEmpty()); + } + + @Test + public void test_scanBatch_withEmptyBasicRecognizers_returnsEmptyResults() { + // Given + List columns = Arrays.asList(testColumn); + List emptyBasicRecognizers = Collections.emptyList(); + + // When + Map results = strategy.scanBatch(columns, emptyBasicRecognizers, aiRecognizers); + + // Then + Assert.assertEquals("Should return results for all columns", 1, results.size()); + String key = getColumnKey(testColumn); + Assert.assertFalse("Should not have basic result", results.get(key).getBasicRuleResult().isPresent()); + Assert.assertFalse("Should not have AI result", results.get(key).getAiRuleResult().isPresent()); + } + + private DBTableColumn createTestColumn(String columnName, String typeName, String comment) { + DBTableColumn column = new DBTableColumn(); + column.setName(columnName); + column.setTypeName(typeName); + column.setComment(comment); + column.setSchemaName("test_schema"); + column.setTableName("test_table"); + return column; + } + + private RecognitionResult createRecognitionResult(Long ruleId, SensitiveLevel level, SensitiveRuleType ruleType) { + return RecognitionResult.builder() + .matched(true) + .matchedRuleId(ruleId) + .level(level) + .sourceRuleType(ruleType) + .build(); + } + + private String getColumnKey(DBTableColumn column) { + return String.format("%s.%s.%s", + column.getSchemaName() != null ? column.getSchemaName() : "unknown_schema", + column.getTableName() != null ? column.getTableName() : "unknown_table", + column.getName() != null ? column.getName() : "unknown_column"); + } +}