From d869af530a14f49eafe14b7d6b1290774ece5078 Mon Sep 17 00:00:00 2001 From: Prashant Pandey Date: Thu, 30 Apr 2026 15:14:16 +0530 Subject: [PATCH 1/2] Batch updated by key groups --- .../FlatCollectionWriteTest.java | 118 +++++++++ .../postgres/FlatPostgresCollection.java | 230 ++++++++++++++---- 2 files changed, 307 insertions(+), 41 deletions(-) diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java index a52a55b2..c7c2ffd1 100644 --- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java +++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java @@ -3531,6 +3531,124 @@ void testBulkUpdateAllOperatorTypes() throws Exception { } } + @Test + @DisplayName( + "Should efficiently batch updates across multiple key groups with complex operations") + void testBulkUpdateMultipleGroupsComplexOperations() throws Exception { + Map> updates = new LinkedHashMap<>(); + + // ===== Group 1: Top-level primitive + top-level array (3 keys: 1, 5, 8) ===== + // All have item="Soap" - these should be batched together + // This tests: SET on primitive field, APPEND_TO_LIST on array field + List group1Updates = + List.of( + SubDocumentUpdate.of("price", 99), // SET operator (top-level primitive) + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.APPEND_TO_LIST) + .subDocumentValue(SubDocumentValue.of(new String[] {"updated-tag", "batch-test"})) + .build()); // APPEND_TO_LIST on top-level array + + updates.put(rawKey("1"), group1Updates); + updates.put(rawKey("5"), group1Updates); + updates.put(rawKey("8"), group1Updates); + + // ===== Group 2: Nested JSONB updates (2 keys: 3, 7) ===== + // Both have props - these should be batched together + // This tests: SET on nested JSONB fields + List group2Updates = + List.of( + SubDocumentUpdate.builder() + .subDocument("props.brand") + .operator(UpdateOperator.SET) + .subDocumentValue(SubDocumentValue.of("PremiumBrand")) + .build(), // SET on nested JSONB primitive + SubDocumentUpdate.builder() + .subDocument("props.size") + .operator(UpdateOperator.SET) + .subDocumentValue(SubDocumentValue.of("XL")) + .build()); // SET on another nested field + + updates.put(rawKey("3"), group2Updates); + updates.put(rawKey("7"), group2Updates); + + // ===== Group 3: ADD operator + REMOVE_ALL_FROM_LIST (2 keys: 2, 6) ===== + // Both have quantity and tags - these should be batched together + // This tests: ADD on numeric field, REMOVE_ALL_FROM_LIST on array + List group3Updates = + List.of( + SubDocumentUpdate.builder() + .subDocument("quantity") + .operator(UpdateOperator.ADD) + .subDocumentValue(SubDocumentValue.of(100)) + .build(), // ADD to numeric field + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[] {"glass", "plastic"})) + .build()); // REMOVE_ALL_FROM_LIST + + updates.put(rawKey("2"), group3Updates); + updates.put(rawKey("6"), group3Updates); + + // Execute bulk update - should have 3 groups with 2-3 keys each + BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); + + // Total unique keys: 1, 2, 3, 5, 6, 7, 8 = 7 keys + assertEquals(7, result.getUpdatedCount(), "Should update 7 rows"); + + // Verify keys 1, 5, 8 have Group 1 updates (top-level primitive + array) + for (String id : List.of("1", "5", "8")) { + try (CloseableIterator iter = flatCollection.find(queryById(id))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(99, json.get("price").asInt(), "Key " + id + " price should be 99"); + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertTrue( + tagList.contains("updated-tag"), "Key " + id + " should contain 'updated-tag'"); + assertTrue(tagList.contains("batch-test"), "Key " + id + " should contain 'batch-test'"); + } + } + + // Verify keys 3, 7 have Group 2 updates (nested JSONB) + for (String id : List.of("3", "7")) { + try (CloseableIterator iter = flatCollection.find(queryById(id))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode props = json.get("props"); + assertNotNull(props, "Key " + id + " should have props"); + assertEquals( + "PremiumBrand", + props.get("brand").asText(), + "Key " + id + " brand should be updated"); + assertEquals("XL", props.get("size").asText(), "Key " + id + " size should be XL"); + } + } + + // Verify keys 2, 6 have Group 3 updates (ADD + REMOVE_ALL_FROM_LIST) + try (CloseableIterator iter = flatCollection.find(queryById("2"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(101, json.get("quantity").asInt()); // 1 + 100 + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertFalse(tagList.contains("glass"), "Key 2 should not have 'glass' tag"); + } + + try (CloseableIterator iter = flatCollection.find(queryById("6"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(105, json.get("quantity").asInt()); // 5 + 100 + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertFalse(tagList.contains("plastic"), "Key 6 should not have 'plastic' tag"); + } + } + @Test @DisplayName("Should handle edge cases: empty map, null map, non-existent keys") void testBulkUpdateEdgeCases() throws Exception { diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java index 440b295b..3cef66cc 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -874,59 +875,34 @@ public BulkUpdateResult bulkUpdate( String tableName = tableIdentifier.getTableName(); String quotedPkColumn = PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName)); - - Set updatedKeys = new HashSet<>(); - long batchUpdateTimestamp = System.currentTimeMillis(); - try (Connection connection = client.getPooledConnection()) { - for (Map.Entry> entry : updates.entrySet()) { - Key key = entry.getKey(); - Collection keyUpdates = entry.getValue(); + // Group keys by their "SQL shape" (same update operations) + Map keyGroups = groupKeysByUpdateShape(updates, tableName); - if (keyUpdates == null || keyUpdates.isEmpty()) { - continue; - } + int totalUpdated = 0; + try (Connection connection = client.getPooledConnection()) { + // Execute one multi-row UPDATE per group (or fallback to single-key if group size = 1) + for (Map.Entry entry : keyGroups.entrySet()) { try { - boolean updated = - updateSingleKey( - connection, key, keyUpdates, tableName, quotedPkColumn, batchUpdateTimestamp); - if (updated) { - updatedKeys.add(key); - } + int updated = + executeBatchUpdate( + connection, entry.getValue(), tableName, quotedPkColumn, batchUpdateTimestamp); + totalUpdated += updated; } catch (Exception e) { - LOGGER.warn("Failed to update key {}: {}", key, e.getMessage()); - // Continue with other keys - no cross-key atomicity + LOGGER.warn( + "Failed to update key group (size: {}): {}", + entry.getValue().getKeys().size(), + e.getMessage()); + // Continue with other groups - no cross-group atomicity } } } catch (SQLException e) { throw new IOException("Failed to get connection for bulk update", e); } - return new BulkUpdateResult(updatedKeys.size()); - } - - private boolean updateSingleKey( - Connection connection, - Key key, - Collection keyUpdates, - String tableName, - String quotedPkColumn, - long keyUpdateTimestamp) - throws IOException, SQLException { - - updateValidator.validate(keyUpdates); - Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); - - return executeKeyUpdate( - connection, - key, - keyUpdates, - tableName, - quotedPkColumn, - resolvedColumns, - keyUpdateTimestamp); + return new BulkUpdateResult(totalUpdated); } private boolean executeKeyUpdate( @@ -972,6 +948,178 @@ private boolean executeKeyUpdate( } } + /** + * Groups keys that have identical update operations together. Keys with the same "shape" can be + * updated in a single multi-row statement. + */ + private Map groupKeysByUpdateShape( + Map> updates, String tableName) { + + Map groups = new LinkedHashMap<>(); + + for (Map.Entry> entry : updates.entrySet()) { + Key key = entry.getKey(); + Collection keyUpdates = entry.getValue(); + + if (keyUpdates == null || keyUpdates.isEmpty()) { + continue; + } + + try { + updateValidator.validate(keyUpdates); + Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); + + String shapeKey = computeUpdateShapeKey(keyUpdates, resolvedColumns); + + groups + .computeIfAbsent(shapeKey, k -> new KeyUpdateGroup(resolvedColumns)) + .addKeyWithUpdates(key, keyUpdates); + + } catch (Exception e) { + LOGGER.warn("Failed to group key {}: {}", key, e.getMessage()); + } + } + + return groups; + } + + private String computeUpdateShapeKey( + Collection updates, Map resolvedColumns) { + + List sorted = new ArrayList<>(updates); + sorted.sort(Comparator.comparing(u -> u.getSubDocument().getPath())); + + StringBuilder sb = new StringBuilder(); + for (SubDocumentUpdate update : sorted) { + String path = update.getSubDocument().getPath(); + String column = resolvedColumns.get(path); + sb.append(column) + .append(":") + .append(update.getOperator()) + .append(":") + .append(path) + .append(";"); + } + + return sb.toString(); + } + + /** + * Executes a batch UPDATE for all keys in the group using JDBC batching. All keys in the group + * share the same SQL structure, so we can use a single PreparedStatement. + */ + private int executeBatchUpdate( + Connection connection, + KeyUpdateGroup keyGroup, + String tableName, + String quotedPkColumn, + long epochMillis) + throws SQLException { + + List keys = keyGroup.getKeys(); + List> allKeyUpdates = keyGroup.getKeyUpdates(); + Map resolvedColumns = keyGroup.getResolvedColumns(); + + // Use the first key's updates to build the SQL template + Collection templateUpdates = allKeyUpdates.get(0); + List setFragments = new ArrayList<>(); + List templateParams = new ArrayList<>(); + + boolean hasUpdates = + buildSetClauseFragments( + connection, templateUpdates, tableName, resolvedColumns, setFragments, templateParams); + + if (!hasUpdates) { + return 0; + } + + appendLastUpdatedTimestamp(setFragments, templateParams, tableName, epochMillis); + + // Build UPDATE SQL (same for all keys in this group) + String sql = + String.format( + "UPDATE %s SET %s WHERE %s = ?", + tableIdentifier, String.join(", ", setFragments), quotedPkColumn); + + LOGGER.debug("Executing batch update SQL: {} for {} keys", sql, keys.size()); + + // Use JDBC batching to execute all updates in one round-trip + try (PreparedStatement ps = connection.prepareStatement(sql)) { + for (int i = 0; i < keys.size(); i++) { + Key key = keys.get(i); + Collection keyUpdates = allKeyUpdates.get(i); + + // Build parameters for this specific key + List keySetFragments = new ArrayList<>(); + List keyParams = new ArrayList<>(); + buildSetClauseFragments( + connection, keyUpdates, tableName, resolvedColumns, keySetFragments, keyParams); + + // Add timestamp parameter + if (lastUpdatedTsColumn != null) { + Optional colMeta = + schemaRegistry.getColumnOrRefresh(tableName, lastUpdatedTsColumn); + if (colMeta.isPresent()) { + Object timestampValue = + convertTimestampForType(epochMillis, colMeta.get().getPostgresType()); + keyParams.add(timestampValue); + } + } + + // Bind parameters for this key + int idx = 1; + for (Object param : keyParams) { + ps.setObject(idx++, param); + } + ps.setObject(idx, key.toString()); // WHERE clause parameter + + ps.addBatch(); + } + + int[] results = ps.executeBatch(); + int totalUpdated = 0; + for (int result : results) { + if (result > 0) { + totalUpdated++; + } + } + + LOGGER.debug("Batch update affected {} rows out of {} keys", totalUpdated, keys.size()); + return totalUpdated; + } catch (SQLException e) { + LOGGER.warn("Failed to execute batch update. SQL: {}, Error: {}", sql, e.getMessage()); + throw e; + } + } + + /** Holds a group of keys that share the same update shape. */ + private static class KeyUpdateGroup { + private final Map resolvedColumns; + private final List keys = new ArrayList<>(); + private final List> keyUpdates = new ArrayList<>(); + + KeyUpdateGroup(Map resolvedColumns) { + this.resolvedColumns = resolvedColumns; + } + + void addKeyWithUpdates(Key key, Collection updates) { + keys.add(key); + keyUpdates.add(updates); + } + + Map getResolvedColumns() { + return resolvedColumns; + } + + List getKeys() { + return keys; + } + + List> getKeyUpdates() { + return keyUpdates; + } + } + /** * Validates all updates and resolves column names. * From 5db4adbeae3991f4813ac1e4804e10d5ce43390c Mon Sep 17 00:00:00 2001 From: Prashant Pandey Date: Tue, 19 May 2026 01:46:09 +0530 Subject: [PATCH 2/2] WIP --- .../FlatCollectionWriteTest.java | 84 ++++++++--- .../postgres/FlatPostgresCollection.java | 134 +++++++----------- 2 files changed, 119 insertions(+), 99 deletions(-) diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java index 6d13302a..b564e1b2 100644 --- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java +++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java @@ -2290,15 +2290,15 @@ void testSetAllFieldTypes() throws Exception { SubDocumentUpdate.of("rating", 4.5f), SubDocumentUpdate.of("weight", 123.456), // Case 2: Top-level arrays - SubDocumentUpdate.of("tags", new String[] {"tag4", "tag5", "tag6"}), - SubDocumentUpdate.of("numbers", new Integer[] {10, 20, 30}), - SubDocumentUpdate.of("scores", new Double[] {1.1, 2.2, 3.3}), - SubDocumentUpdate.of("flags", new Boolean[] {true, false, true}), + SubDocumentUpdate.of("tags", new String[]{"tag4", "tag5", "tag6"}), + SubDocumentUpdate.of("numbers", new Integer[]{10, 20, 30}), + SubDocumentUpdate.of("scores", new Double[]{1.1, 2.2, 3.3}), + SubDocumentUpdate.of("flags", new Boolean[]{true, false, true}), // Case 3 & 4: One nested path in JSONB (props) - tests nested primitive SubDocumentUpdate.of("props.brand", "NewBrand"), // Use 'sales' JSONB column for nested array test SubDocumentUpdate.of( - "sales.regions", SubDocumentValue.of(new String[] {"US", "EU", "APAC"}))); + "sales.regions", SubDocumentValue.of(new String[]{"US", "EU", "APAC"}))); UpdateOptions options = UpdateOptions.builder().returnDocumentType(ReturnDocumentType.AFTER_UPDATE).build(); @@ -2510,7 +2510,7 @@ void testSetMultipleNestedPathsInSameJsonbColumn() throws Exception { SubDocumentUpdate.of("props.size", "XL"), SubDocumentUpdate.of("props.newField", "newValue"), SubDocumentUpdate.of( - "props.owners", SubDocumentValue.of(new String[] {"owner1", "owner2"}))); + "props.owners", SubDocumentValue.of(new String[]{"owner1", "owner2"}))); UpdateOptions options = UpdateOptions.builder().returnDocumentType(ReturnDocumentType.AFTER_UPDATE).build(); @@ -2818,7 +2818,7 @@ void testAddArrayValue() { SubDocumentUpdate.builder() .subDocument("price") .operator(UpdateOperator.ADD) - .subDocumentValue(SubDocumentValue.of(new Integer[] {1, 2, 3})) + .subDocumentValue(SubDocumentValue.of(new Integer[]{1, 2, 3})) .build()); UpdateOptions options = @@ -2865,19 +2865,19 @@ void testAppendToListAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"newTag1", "newTag2"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"newTag1", "newTag2"})) .build(), // Nested JSONB array: append to existing props.colors SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"green", "yellow"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"green", "yellow"})) .build(), // Nested JSONB: append to non-existent array (creates it) SubDocumentUpdate.builder() .subDocument("sales.regions") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"US", "EU"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"US", "EU"})) .build()); UpdateOptions options = @@ -2956,13 +2956,13 @@ void testAddToListIfAbsentAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"existing1", "newTag"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"existing1", "newTag"})) .build(), // Nested JSONB: 'red' exists, 'green' is new → adds only 'green' SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"red", "green"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"red", "green"})) .build()); UpdateOptions options = @@ -3033,13 +3033,13 @@ void testRemoveAllFromListAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"tag1"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"tag1"})) .build(), // Nested JSONB: remove 'red' and 'blue' → leaves green SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"red", "blue"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"red", "blue"})) .build()); UpdateOptions options = @@ -3512,7 +3512,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"newTag1", "newTag2"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"newTag1", "newTag2"})) .build())); updates.put( @@ -3521,7 +3521,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"hygiene", "uniqueTag"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"hygiene", "uniqueTag"})) .build())); updates.put( @@ -3530,7 +3530,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"plastic"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"plastic"})) .build())); BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); @@ -3591,7 +3591,7 @@ void testBulkUpdateMultipleGroupsComplexOperations() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"updated-tag", "batch-test"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"updated-tag", "batch-test"})) .build()); // APPEND_TO_LIST on top-level array updates.put(rawKey("1"), group1Updates); @@ -3630,7 +3630,7 @@ void testBulkUpdateMultipleGroupsComplexOperations() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"glass", "plastic"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"glass", "plastic"})) .build()); // REMOVE_ALL_FROM_LIST updates.put(rawKey("2"), group3Updates); @@ -3694,6 +3694,52 @@ void testBulkUpdateMultipleGroupsComplexOperations() throws Exception { } } + @Test + @DisplayName( + "Should batch keys whose update shape matches by column:operator:path but whose value " + + "arrays differ in length (nested JSONB REMOVE_ALL_FROM_LIST)") + void testBulkUpdateSameShapeDifferentParamCardinality() throws Exception { + Map> updates = new LinkedHashMap<>(); + + updates.put( + rawKey("1"), + List.of( + SubDocumentUpdate.builder() + .subDocument("props.colors") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"Blue"})) + .build())); + + updates.put( + rawKey("5"), + List.of( + SubDocumentUpdate.builder() + .subDocument("props.colors") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"Orange", "Blue"})) + .build())); + + BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); + + assertEquals(2, result.getUpdatedCount()); + + try (CloseableIterator iter = flatCollection.find(queryById("1"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode colors = json.get("props").get("colors"); + List colorList = new ArrayList<>(); + colors.forEach(c -> colorList.add(c.asText())); + assertEquals(List.of("Green"), colorList); + } + + try (CloseableIterator iter = flatCollection.find(queryById("5"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode colors = json.get("props").get("colors"); + assertEquals(0, colors.size()); + } + } + @Test @DisplayName("Should handle edge cases: empty map, null map, non-existent keys") void testBulkUpdateEdgeCases() throws Exception { diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java index 2e3e023d..cbbcbfa3 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java @@ -28,7 +28,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -98,6 +97,7 @@ public class FlatPostgresCollection extends PostgresCollection { "Write operations are not supported for flat collections yet!"; private static final String MISSING_COLUMN_STRATEGY_CONFIG = "missingColumnStrategy"; private static final String DEFAULT_PRIMARY_KEY_COLUMN = "key"; + private static final String SHAPE_KEY_DELIMITER = "\u0001"; private static final Map UPDATE_PARSER_MAP = Map.ofEntries( @@ -883,12 +883,13 @@ public BulkUpdateResult bulkUpdate( String quotedPkColumn = PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName)); long batchUpdateTimestamp = System.currentTimeMillis(); - // Group keys by their "SQL shape" (same update operations) - Map keyGroups = groupKeysByUpdateShape(updates, tableName); - int totalUpdated = 0; try (Connection connection = client.getPooledConnection()) { + // Group keys by their "SQL shape" (same SET-clause fragments AND param count) + Map keyGroups = + groupKeysByUpdateShape(connection, updates, tableName); + // Execute one multi-row UPDATE per group (or fallback to single-key if group size = 1) for (Map.Entry entry : keyGroups.entrySet()) { try { @@ -955,11 +956,14 @@ private boolean executeKeyUpdate( } /** - * Groups keys that have identical update operations together. Keys with the same "shape" can be - * updated in a single multi-row statement. + * Groups keys that produce identical SET-clause SQL together. Two keys share a shape only if + * {@code buildSetClauseFragments} renders the exact same fragment list and the same number of + * bind parameters — required because {@code executeBatchUpdate} reuses one PreparedStatement per + * group. Operators whose generated SQL or placeholder count varies with the input value (e.g., + * nested-JSONB REMOVE_ALL_FROM_LIST emitting 1+N placeholders) will land in distinct groups. */ private Map groupKeysByUpdateShape( - Map> updates, String tableName) { + Connection connection, Map> updates, String tableName) { Map groups = new LinkedHashMap<>(); @@ -975,11 +979,20 @@ private Map groupKeysByUpdateShape( updateValidator.validate(keyUpdates); Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); - String shapeKey = computeUpdateShapeKey(keyUpdates, resolvedColumns); + List setFragments = new ArrayList<>(); + List params = new ArrayList<>(); + boolean hasUpdates = + buildSetClauseFragments( + connection, keyUpdates, tableName, resolvedColumns, setFragments, params); + if (!hasUpdates) { + continue; + } + + String shapeKey = computeUpdateShapeKey(setFragments, params.size()); groups - .computeIfAbsent(shapeKey, k -> new KeyUpdateGroup(resolvedColumns)) - .addKeyWithUpdates(key, keyUpdates); + .computeIfAbsent(shapeKey, k -> new KeyUpdateGroup(resolvedColumns, setFragments)) + .addKeyWithParams(key, params); } catch (Exception e) { LOGGER.warn("Failed to group key {}: {}", key, e.getMessage()); @@ -989,30 +1002,14 @@ private Map groupKeysByUpdateShape( return groups; } - private String computeUpdateShapeKey( - Collection updates, Map resolvedColumns) { - - List sorted = new ArrayList<>(updates); - sorted.sort(Comparator.comparing(u -> u.getSubDocument().getPath())); - - StringBuilder sb = new StringBuilder(); - for (SubDocumentUpdate update : sorted) { - String path = update.getSubDocument().getPath(); - String column = resolvedColumns.get(path); - sb.append(column) - .append(":") - .append(update.getOperator()) - .append(":") - .append(path) - .append(";"); - } - - return sb.toString(); + private String computeUpdateShapeKey(List setFragments, int paramCount) { + return paramCount + "|" + String.join(SHAPE_KEY_DELIMITER, setFragments); } /** - * Executes a batch UPDATE for all keys in the group using JDBC batching. All keys in the group - * share the same SQL structure, so we can use a single PreparedStatement. + * Executes a batch UPDATE for all keys in the group using JDBC batching. The group's setFragments + * and per-key params were rendered during grouping, so all keys here are guaranteed to share the + * same SQL and placeholder count. */ private int executeBatchUpdate( Connection connection, @@ -1023,25 +1020,12 @@ private int executeBatchUpdate( throws SQLException { List keys = keyGroup.getKeys(); - List> allKeyUpdates = keyGroup.getKeyUpdates(); - Map resolvedColumns = keyGroup.getResolvedColumns(); - - // Use the first key's updates to build the SQL template - Collection templateUpdates = allKeyUpdates.get(0); - List setFragments = new ArrayList<>(); - List templateParams = new ArrayList<>(); - - boolean hasUpdates = - buildSetClauseFragments( - connection, templateUpdates, tableName, resolvedColumns, setFragments, templateParams); + List> allKeyParams = keyGroup.getKeyParams(); - if (!hasUpdates) { - return 0; - } + List setFragments = new ArrayList<>(keyGroup.getSetFragments()); + List timestampParam = new ArrayList<>(); + appendLastUpdatedTimestamp(setFragments, timestampParam, tableName, epochMillis); - appendLastUpdatedTimestamp(setFragments, templateParams, tableName, epochMillis); - - // Build UPDATE SQL (same for all keys in this group) String sql = String.format( "UPDATE %s SET %s WHERE %s = ?", @@ -1049,36 +1033,16 @@ private int executeBatchUpdate( LOGGER.debug("Executing batch update SQL: {} for {} keys", sql, keys.size()); - // Use JDBC batching to execute all updates in one round-trip try (PreparedStatement ps = connection.prepareStatement(sql)) { for (int i = 0; i < keys.size(); i++) { - Key key = keys.get(i); - Collection keyUpdates = allKeyUpdates.get(i); - - // Build parameters for this specific key - List keySetFragments = new ArrayList<>(); - List keyParams = new ArrayList<>(); - buildSetClauseFragments( - connection, keyUpdates, tableName, resolvedColumns, keySetFragments, keyParams); - - // Add timestamp parameter - if (lastUpdatedTsColumn != null) { - Optional colMeta = - schemaRegistry.getColumnOrRefresh(tableName, lastUpdatedTsColumn); - if (colMeta.isPresent()) { - Object timestampValue = - convertTimestampForType(epochMillis, colMeta.get().getPostgresType()); - keyParams.add(timestampValue); - } - } - - // Bind parameters for this key int idx = 1; - for (Object param : keyParams) { + for (Object param : allKeyParams.get(i)) { ps.setObject(idx++, param); } - ps.setObject(idx, key.toString()); // WHERE clause parameter - + for (Object param : timestampParam) { + ps.setObject(idx++, param); + } + ps.setObject(idx, keys.get(i).toString()); // WHERE clause parameter ps.addBatch(); } @@ -1098,31 +1062,41 @@ private int executeBatchUpdate( } } - /** Holds a group of keys that share the same update shape. */ + /** + * Holds a group of keys that share the same SET-clause SQL. {@code setFragments} is rendered once + * during grouping; {@code keyParams} stores the bind values for each key in lockstep with {@code + * keys}. + */ private static class KeyUpdateGroup { private final Map resolvedColumns; + private final List setFragments; private final List keys = new ArrayList<>(); - private final List> keyUpdates = new ArrayList<>(); + private final List> keyParams = new ArrayList<>(); - KeyUpdateGroup(Map resolvedColumns) { + KeyUpdateGroup(Map resolvedColumns, List setFragments) { this.resolvedColumns = resolvedColumns; + this.setFragments = setFragments; } - void addKeyWithUpdates(Key key, Collection updates) { + void addKeyWithParams(Key key, List params) { keys.add(key); - keyUpdates.add(updates); + keyParams.add(params); } Map getResolvedColumns() { return resolvedColumns; } + List getSetFragments() { + return setFragments; + } + List getKeys() { return keys; } - List> getKeyUpdates() { - return keyUpdates; + List> getKeyParams() { + return keyParams; } }