From bc1ba3f20798ff9058569f5326340a8bff53f67b Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: Thu, 16 Apr 2026 08:52:10 +0530 Subject: [PATCH 1/6] fix: Azure provisioning, generator robustness, and missing masking functions - Makefile.shared: replace all bare `python` with `python3` for macOS compat - setup_demo.py: fix Azure workspace creation (ARM REST API instead of SDK), fix warehouse_type enum, add max_num_clusters, remove invalid managed_resource_group_id - generate_abac.py: pick most complete HCL/SQL block from LLM response (not first), strip SQL-style comments from HCL output, clean up stray commas after policy removal, remove orphaned tag assignments when policies are dropped for missing functions - financial_services.yaml: add 6 missing core masking functions (mask_redact, mask_email, mask_phone, mask_date_to_year, mask_credit_card_full, filter_aml_compliance) Co-authored-by: Isaac --- shared/Makefile.shared | 38 ++++---- shared/examples/aus_bank_demo/setup_demo.py | 102 ++++++++++++++++---- shared/generate_abac.py | 99 +++++++++++++++++-- shared/industries/financial_services.yaml | 40 ++++++++ 4 files changed, 230 insertions(+), 49 deletions(-) diff --git a/shared/Makefile.shared b/shared/Makefile.shared index 6146ce8..7917c24 100644 --- a/shared/Makefile.shared +++ b/shared/Makefile.shared @@ -191,9 +191,9 @@ generate: _bootstrap _guard-workspace-target ## Run generate_abac.py in the sele echo " then use 'make generate SPACE=\"$(SPACE)\"' to add or update individual spaces."; \ exit 1; \ fi; \ - cd "$(ENV_DIR)" && python "$(SHARED_ROOT)/generate_abac.py" $(GENERATE_ARGS) $(if $(MODE),--mode $(MODE),) $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),) --space "$(SPACE)"; \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/generate_abac.py" $(GENERATE_ARGS) $(if $(MODE),--mode $(MODE),) $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),) --space "$(SPACE)"; \ else \ - cd "$(ENV_DIR)" && python 
"$(SHARED_ROOT)/generate_abac.py" $(GENERATE_ARGS) $(if $(MODE),--mode $(MODE),) $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/generate_abac.py" $(GENERATE_ARGS) $(if $(MODE),--mode $(MODE),) $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ fi audit-schema: _bootstrap _guard-workspace-target ## Report untagged sensitive columns and stale tag assignments in managed tables @@ -207,18 +207,18 @@ generate-delta: _bootstrap _guard-workspace-target ## Detect schema drift and ge validate-generated: _bootstrap _guard-workspace-target ## Validate generated/ files in the selected workspace environment @echo "=== Validate (generated/) [$(ENV)] ===" @cd "$(ENV_DIR)" && if [ -f generated/masking_functions.sql ]; then \ - python "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars generated/masking_functions.sql $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ + python3 "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars generated/masking_functions.sql $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ else \ - python "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ + python3 "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ fi validate: _bootstrap ## Validate split config in the selected environment @echo "=== Validate ($(ENV)) ===" @if [ "$(ENV)" = "$(ACCOUNT_ENV)" ]; then \ - cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); fi; \ + cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" 
abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); fi; \ else \ - cd "$(ENV_DIR)/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); fi; \ - cd "$(ENV_DIR)" && python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ + cd "$(ENV_DIR)/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); fi; \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars $(if $(COUNTRY),--country $(COUNTRY),) $(if $(INDUSTRY),--industry $(INDUSTRY),); \ fi promote: _bootstrap _guard-workspace-target ## Split generated/ into layers (same-env), or cross-env: make promote SOURCE_ENV=dev DEST_ENV=prod DEST_CATALOG_MAP="dev_cat=prod_cat" @@ -264,9 +264,9 @@ print(', '.join(cats) if cats else '(none detected)') \ $$map_flags; \ cd "$$dest_env_dir" && python3 "$(SHARED_ROOT)/scripts/split_abac_config.py" generated/abac.auto.tfvars ../$(ACCOUNT_ENV)/abac.auto.tfvars $(DATA_ACCESS_SUBDIR)/abac.auto.tfvars abac.auto.tfvars; \ cp "$$dest_env_dir/generated/masking_functions.sql" "$$dest_env_dir/$(DATA_ACCESS_SUBDIR)/masking_functions.sql"; \ - cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; fi; \ - cd "$$dest_env_dir/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql; fi; \ - cd "$$dest_env_dir" && python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; \ + cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" 
abac.auto.tfvars; fi; \ + cd "$$dest_env_dir/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql; fi; \ + cd "$$dest_env_dir" && python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; \ echo ""; \ echo "=== Promote complete: $(SOURCE_ENV) -> $(DEST_ENV) ==="; \ echo " Next: edit $$dest_env_dir/auth.auto.tfvars ($(DEST_ENV) workspace credentials)"; \ @@ -281,17 +281,17 @@ print(', '.join(cats) if cats else '(none detected)') \ fi; \ echo "=== Split generated/ to account + data_access + workspace ($(ENV)) ==="; \ if [ -f "$(ENV_DIR)/generated/masking_functions.sql" ]; then \ - cd "$(ENV_DIR)" && python "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars generated/masking_functions.sql; \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars generated/masking_functions.sql; \ else \ - cd "$(ENV_DIR)" && python "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars; \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/validate_abac.py" generated/abac.auto.tfvars; \ fi; \ cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/scripts/split_abac_config.py" generated/abac.auto.tfvars ../$(ACCOUNT_ENV)/abac.auto.tfvars $(DATA_ACCESS_SUBDIR)/abac.auto.tfvars abac.auto.tfvars; \ if [ -f "$(ENV_DIR)/generated/masking_functions.sql" ]; then \ cp "$(ENV_DIR)/generated/masking_functions.sql" "$(ENV_DIR)/$(DATA_ACCESS_SUBDIR)/masking_functions.sql"; \ fi; \ - cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; fi; \ - cd "$(ENV_DIR)/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql; fi; \ - cd "$(ENV_DIR)" && python "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; \ + cd "$(ACCOUNT_ENV_DIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; fi; \ + cd 
"$(ENV_DIR)/$(DATA_ACCESS_SUBDIR)" && if [ -f abac.auto.tfvars ]; then python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars masking_functions.sql; fi; \ + cd "$(ENV_DIR)" && python3 "$(SHARED_ROOT)/validate_abac.py" abac.auto.tfvars; \ fi _plan-layer: @@ -722,7 +722,7 @@ integration-test: ## Run the full end-to-end integration test (setup → generat @echo "========================================================" @echo "" @echo "── Step 1: Create test catalogs (dev + prod) ──────────" - cd "$(CLOUD_ROOT)" && python $(SHARED_ROOT)/scripts/setup_test_data.py \ + cd "$(CLOUD_ROOT)" && python3 $(SHARED_ROOT)/scripts/setup_test_data.py \ --auth-file "$(ITEST_AUTH)" $(ITEST_WH_FLAG) --prod @echo "" @echo "── Step 2: Env scaffolding ─────────────────────────────" @@ -738,7 +738,7 @@ integration-test: ## Run the full end-to-end integration test (setup → generat $(MAKE) --no-print-directory apply ENV="$(ENV)" @echo "" @echo "── Step 6: Verify dev (row counts + ABAC governance) ───" - cd "$(CLOUD_ROOT)" && python $(SHARED_ROOT)/scripts/setup_test_data.py \ + cd "$(CLOUD_ROOT)" && python3 $(SHARED_ROOT)/scripts/setup_test_data.py \ --auth-file "$(ITEST_AUTH)" $(ITEST_WH_FLAG) --verify @echo "" @echo "── Step 7: Per-space generation test (Finance only) ────" @@ -753,12 +753,12 @@ integration-test: ## Run the full end-to-end integration test (setup → generat $(MAKE) --no-print-directory apply ENV="$(ITEST_DEST_ENV)" @echo "" @echo "── Step 10: Verify prod (row counts + ABAC governance) ─" - cd "$(CLOUD_ROOT)" && python $(SHARED_ROOT)/scripts/setup_test_data.py \ + cd "$(CLOUD_ROOT)" && python3 $(SHARED_ROOT)/scripts/setup_test_data.py \ --auth-file "$(ITEST_AUTH)" $(ITEST_WH_FLAG) --verify-prod @echo "" @if [ -z "$(KEEP_DATA)" ]; then \ echo "── Step 11: Teardown ────────────────────────────────────"; \ - cd "$(CLOUD_ROOT)" && python $(SHARED_ROOT)/scripts/setup_test_data.py \ + cd "$(CLOUD_ROOT)" && python3 $(SHARED_ROOT)/scripts/setup_test_data.py \ --auth-file 
"$(ITEST_AUTH)" $(ITEST_WH_FLAG) --teardown --teardown-prod; \ $(MAKE) --no-print-directory destroy ENV="$(ITEST_DEST_ENV)"; \ $(MAKE) --no-print-directory destroy ENV="$(ENV)"; \ diff --git a/shared/examples/aus_bank_demo/setup_demo.py b/shared/examples/aus_bank_demo/setup_demo.py index c0bf077..7547e8f 100644 --- a/shared/examples/aus_bank_demo/setup_demo.py +++ b/shared/examples/aus_bank_demo/setup_demo.py @@ -447,10 +447,12 @@ def _create_tables_via_sdk(dev_state: dict) -> str: wh_id = wh.id break if not wh_id: + from databricks.sdk.service.sql import CreateWarehouseRequestWarehouseType wh = w.warehouses.create( name="Demo Warehouse", cluster_size="2X-Small", - warehouse_type="PRO", + warehouse_type=CreateWarehouseRequestWarehouseType.PRO, + max_num_clusters=1, auto_stop_mins=15, enable_serverless_compute=True, ).result() @@ -544,31 +546,91 @@ def _create_prod_workspace(cfg: dict, cloud: str, metastore_id: str, dev_state: else: ws_kwargs = { "location": region, - "managed_resource_group_id": ( - f"/subscriptions/{cfg.get('AZURE_SUBSCRIPTION_ID', '')}" - f"/resourceGroups/{prod_ws_name}-managed" - ), } print(f" Creating workspace: {prod_ws_name} in {region}...") - try: - from databricks.sdk.service.provisioning import ( - CustomerFacingComputeMode, - PricingTier, + if cloud == "azure": + # Azure workspaces must be created via ARM REST API, not the account SDK + import json + import urllib.request + import urllib.error + from azure.identity import ClientSecretCredential + + subscription_id = cfg.get("AZURE_SUBSCRIPTION_ID", "") + resource_group = cfg.get("AZURE_RESOURCE_GROUP", "") + arm_cred = ClientSecretCredential( + tenant_id=cfg.get("AZURE_TENANT_ID", ""), + client_id=cfg.get("AZURE_CLIENT_ID", ""), + client_secret=cfg.get("AZURE_CLIENT_SECRET", ""), ) - ws = a.workspaces.create_and_wait( - workspace_name=prod_ws_name, - pricing_tier=PricingTier.ENTERPRISE, - compute_mode=CustomerFacingComputeMode.SERVERLESS, - **ws_kwargs, + arm_token = 
arm_cred.get_token("https://management.azure.com/.default").token + arm_api_version = "2025-10-01-preview" + + arm_url = ( + f"https://management.azure.com/subscriptions/{subscription_id}" + f"/resourceGroups/{resource_group}" + f"/providers/Microsoft.Databricks/workspaces/{prod_ws_name}" + f"?api-version={arm_api_version}" ) - except (ImportError, TypeError): - # Fallback for older SDK versions without compute_mode - ws = a.workspaces.create(workspace_name=prod_ws_name, **ws_kwargs).result() + arm_body = json.dumps({ + "location": region, + "sku": {"name": "premium"}, + "properties": {"computeMode": "Serverless"}, + "tags": {"ManagedBy": "setup_demo"}, + }).encode() + + req = urllib.request.Request(arm_url, data=arm_body, method="PUT", headers={ + "Authorization": f"Bearer {arm_token}", + "Content-Type": "application/json", + }) + try: + with urllib.request.urlopen(req) as resp: + data = json.loads(resp.read()) + except urllib.error.HTTPError as e: + detail = e.read().decode(errors="replace") + raise RuntimeError(f"ARM PUT {e.code}: {detail}") from e + + # Poll until Succeeded + import time as _time + deadline = _time.time() + 600 + while _time.time() < deadline: + _time.sleep(15) + get_req = urllib.request.Request(arm_url, headers={"Authorization": f"Bearer {arm_token}"}) + with urllib.request.urlopen(get_req) as resp: + data = json.loads(resp.read()) + props = data.get("properties", {}) + prov_state = props.get("provisioningState", "Unknown") + elapsed = int(_time.time() - (deadline - 600)) + print(f" [{elapsed}s] {prov_state}") + if prov_state == "Succeeded": + break + if prov_state in ("Failed", "Canceled"): + raise RuntimeError(f"Workspace creation {prov_state}") + else: + raise TimeoutError("Workspace did not reach Succeeded within 10 minutes") + + props = data.get("properties", {}) + prod_ws_id = str(props["workspaceId"]) + ws_url = props["workspaceUrl"] + prod_host = f"https://{ws_url}" if not ws_url.startswith("https://") else ws_url + else: + try: + from 
databricks.sdk.service.provisioning import ( + CustomerFacingComputeMode, + PricingTier, + ) + ws = a.workspaces.create_and_wait( + workspace_name=prod_ws_name, + pricing_tier=PricingTier.ENTERPRISE, + compute_mode=CustomerFacingComputeMode.SERVERLESS, + **ws_kwargs, + ) + except (ImportError, TypeError): + # Fallback for older SDK versions without compute_mode + ws = a.workspaces.create(workspace_name=prod_ws_name, **ws_kwargs).result() - prod_host = (f"https://{ws.deployment_name}.cloud.databricks.com" - if cloud == "aws" else (ws.workspace_url or "")) - prod_ws_id = str(ws.workspace_id) + prod_host = f"https://{ws.deployment_name}.cloud.databricks.com" + prod_ws_id = str(ws.workspace_id) print(f" {_green('✓')} Prod workspace created: {prod_host}") # Assign shared metastore diff --git a/shared/generate_abac.py b/shared/generate_abac.py index 2ac70a9..fc26dfd 100644 --- a/shared/generate_abac.py +++ b/shared/generate_abac.py @@ -1105,18 +1105,34 @@ def _extract_hcl_fallback(text: str) -> str | None: candidate = "\n".join(candidate_lines).strip() return candidate if _looks_like_hcl(candidate) else None + hcl_candidates: list[str] = [] + sql_candidates: list[str] = [] for lang, content in blocks: content = content.strip() lang_lower = lang.lower() - if lang_lower == "sql" and sql_block is None: - sql_block = content - elif lang_lower in ("hcl", "terraform") and hcl_block is None: - hcl_block = content - elif not lang and sql_block is None and "CREATE" in content.upper() and "FUNCTION" in content.upper(): - sql_block = content - elif not lang and hcl_block is None and _looks_like_hcl(content): - hcl_block = content + if lang_lower == "sql": + sql_candidates.append(content) + elif lang_lower in ("hcl", "terraform"): + hcl_candidates.append(content) + elif not lang and "CREATE" in content.upper() and "FUNCTION" in content.upper(): + sql_candidates.append(content) + elif not lang and _looks_like_hcl(content): + hcl_candidates.append(content) + + # Pick the largest SQL 
block (most CREATE FUNCTION statements) + if sql_candidates: + sql_block = max(sql_candidates, key=lambda c: (c.upper().count("CREATE"), len(c))) + + # Pick the most complete HCL block — the one with the most top-level keys + # (groups, tag_policies, tag_assignments, fgac_policies, genie_space_configs). + # The LLM often emits partial blocks before the final complete one. + if hcl_candidates: + def _key_count(c: str) -> int: + keys = ("groups", "tag_policies", "tag_assignments", "fgac_policies", + "genie_space_configs", "genie_space_title", "group_members") + return sum(1 for k in keys if re.search(rf"(?m)^\s*{k}\s*=", c)) + hcl_block = max(hcl_candidates, key=lambda c: (_key_count(c), len(c))) if hcl_block is None: hcl_block = _extract_hcl_fallback(response_text) @@ -1150,6 +1166,9 @@ def sanitize_tfvars_hcl(hcl_block: str) -> str: continue if re.match(r"^\s*#\s*Databricks\s+Authentication\b", line, re.IGNORECASE): continue + # Strip SQL-style comments that the LLM sometimes emits into HCL output + if re.match(r"^\s*--", line): + continue m = re.match(r"^\s*([A-Za-z0-9_]+)\s*=", line) if m and m.group(1) in TFVARS_STRIP_KEYS: @@ -1529,7 +1548,7 @@ def fix_hcl_syntax(tfvars_path: Path) -> int: # Look ahead: find the next non-blank, non-comment line j = i + 1 while j < len(lines) and ( - lines[j].strip() == '' or lines[j].lstrip().startswith('#') + lines[j].strip() == '' or lines[j].lstrip().startswith('#') or lines[j].lstrip().startswith('--') ): j += 1 if j < len(lines): @@ -1588,6 +1607,19 @@ def _object_vals_to_strings(m: re.Match) -> str: repairs += 1 text = fixed4 + # ------------------------------------------------------------------ + # Fix 5: remove stray commas left by autofix block removals. 
+ # - Bare comma lines + # - Consecutive commas + # - Trailing comma before ] + # ------------------------------------------------------------------ + fixed5 = re.sub(r'^\s*,\s*$', '', text, flags=re.MULTILINE) + fixed5 = re.sub(r',(\s*,)+', ',', fixed5) + fixed5 = re.sub(r',(\s*\])', r'\1', fixed5) + if fixed5 != text: + repairs += 1 + text = fixed5 + if text != original: tfvars_path.write_text(text) print(f" [AUTOFIX] Repaired {repairs} HCL syntax issue(s)") @@ -2520,7 +2552,10 @@ def _remove_block(txt: str, block_name: str) -> tuple[str, bool]: ) if removed or assignments_removed: - # Clean up double-blank lines left by removal + # Clean up stray commas and double-blank lines left by removal + text = re.sub(r'^\s*,\s*$', '', text, flags=re.MULTILINE) + text = re.sub(r',(\s*,)+', ',', text) + text = re.sub(r',(\s*\])', r'\1', text) text = re.sub(r"\n{3,}", "\n\n", text) tfvars_path.write_text(text) @@ -3836,7 +3871,51 @@ def autofix_invalid_function_refs(tfvars_path: Path, sql_path: Path | None = Non if not fixes: return 0 + # Clean up stray commas left behind by block removals + rewritten = re.sub(r'^\s*,\s*$', '', rewritten, flags=re.MULTILINE) # bare comma lines + rewritten = re.sub(r',(\s*,)+', ',', rewritten) # consecutive commas + rewritten = re.sub(r',(\s*\])', r'\1', rewritten) # trailing comma before ] + text = text[:sec_start] + rewritten + text[sec_end:] + + # --- Remove orphaned tag_assignments left by removed policies ---------- + # Collect tag key/value pairs that the removed policies covered. 
+ removed_tag_pairs: set[tuple[str, str]] = set() + for pname in removals: + # Find the original policy in the parsed list and extract its match condition + for p in policies: + if p.get("name") == pname: + for field in ("match_condition", "when_condition"): + cond = p.get(field, "") or "" + for m in re.finditer(r"hasTagValue\(\s*'([^']+)'\s*,\s*'([^']+)'\s*\)", cond): + removed_tag_pairs.add((m.group(1), m.group(2))) + + if removed_tag_pairs: + # Check which of these pairs are still covered by a surviving policy + surviving_pairs: set[tuple[str, str]] = set() + for p in policies: + if p.get("name") in removals: + continue + for field in ("match_condition", "when_condition"): + cond = p.get(field, "") or "" + for m in re.finditer(r"hasTagValue\(\s*'([^']+)'\s*,\s*'([^']+)'\s*\)", cond): + surviving_pairs.add((m.group(1), m.group(2))) + + orphaned_pairs = removed_tag_pairs - surviving_pairs + if orphaned_pairs: + for tag_key, tag_value in orphaned_pairs: + # Remove matching tag_assignment lines + pattern = re.compile( + r'[ \t]*\{[^}]*tag_key\s*=\s*"' + re.escape(tag_key) + + r'"[^}]*tag_value\s*=\s*"' + re.escape(tag_value) + + r'"[^}]*\}\s*,?\s*\n?', + ) + new_text, n = pattern.subn('', text) + if n: + text = new_text + print(f" [AUTOFIX] Removed {n} orphaned tag_assignment(s) for " + f"{tag_key} = '{tag_value}' (policy removed)") + tfvars_path.write_text(text) return fixes diff --git a/shared/industries/financial_services.yaml b/shared/industries/financial_services.yaml index 126043a..1244dab 100644 --- a/shared/industries/financial_services.yaml +++ b/shared/industries/financial_services.yaml @@ -144,6 +144,46 @@ masking_functions: ELSE '[REDACTED]' END + - name: mask_redact + signature: "mask_redact(val STRING) RETURNS STRING" + comment: "Full redaction — returns [REDACTED]" + body: "CASE WHEN val IS NULL THEN NULL ELSE '[REDACTED]' END" + + - name: mask_email + signature: "mask_email(email STRING) RETURNS STRING" + comment: "Email — mask local part, keep 
domain" + body: | + CASE + WHEN email IS NULL THEN NULL + WHEN INSTR(email, '@') > 1 THEN CONCAT(LEFT(email, 1), '****@', SUBSTRING(email, INSTR(email, '@') + 1)) + ELSE '[REDACTED]' + END + + - name: mask_phone + signature: "mask_phone(phone STRING) RETURNS STRING" + comment: "Phone number — show last 4 digits only" + body: | + CASE + WHEN phone IS NULL THEN NULL + WHEN LENGTH(REGEXP_REPLACE(phone, '[^0-9]', '')) >= 4 THEN CONCAT('***-***-', RIGHT(REGEXP_REPLACE(phone, '[^0-9]', ''), 4)) + ELSE '[REDACTED]' + END + + - name: mask_date_to_year + signature: "mask_date_to_year(dt DATE) RETURNS DATE" + comment: "Date — truncate to year" + body: "CASE WHEN dt IS NULL THEN NULL ELSE DATE_TRUNC('YEAR', dt) END" + + - name: mask_credit_card_full + signature: "mask_credit_card_full(card STRING) RETURNS STRING" + comment: "Full card redaction — PCI DSS" + body: "CASE WHEN card IS NULL THEN NULL ELSE '[REDACTED]' END" + + - name: filter_aml_compliance + signature: "filter_aml_compliance() RETURNS BOOLEAN" + comment: "AML row filter — only compliance and fraud teams see all rows" + body: "is_account_group_member('fraud_team') OR is_account_group_member('Compliance_Officer') OR is_account_group_member('compliance_officer')" + group_templates: fraud_team: description: "Full access to all financial data for fraud investigation and AML compliance" From e0cb1ff42b204d054e931c47bac632dc8a1660f0 Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: Mon, 20 Apr 2026 13:01:54 +0530 Subject: [PATCH 2/6] =?UTF-8?q?fix:=20address=20PR=20#18=20review=20?= =?UTF-8?q?=E2=80=94=20GA=20API=20version,=20polling=20resilience,=20dedup?= =?UTF-8?q?=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch ARM API version from 2025-10-01-preview to 2024-05-01 (GA) in both setup_demo.py and azure_provider.py - Add try/except with retry on transient 5xx errors in workspace polling loops (setup_demo.py and azure_provider.py) - Extract 
_cleanup_stray_commas() helper in generate_abac.py, replacing 7 duplicate regex sites; also fix pre-existing bug where valid HCL trailing commas after } were incorrectly stripped - Remove generic masking functions (mask_redact, mask_email, mask_phone, mask_date_to_year, mask_credit_card_full) from financial_services.yaml to avoid definition drift with the base system - Normalize filter_aml_compliance to use lowercase compliance_officer Co-authored-by: Isaac --- shared/examples/aus_bank_demo/setup_demo.py | 12 ++++-- shared/generate_abac.py | 42 +++++++++---------- shared/industries/financial_services.yaml | 37 +--------------- .../scripts/cloud_providers/azure_provider.py | 12 ++++-- 4 files changed, 40 insertions(+), 63 deletions(-) diff --git a/shared/examples/aus_bank_demo/setup_demo.py b/shared/examples/aus_bank_demo/setup_demo.py index 7547e8f..8410e7b 100644 --- a/shared/examples/aus_bank_demo/setup_demo.py +++ b/shared/examples/aus_bank_demo/setup_demo.py @@ -564,7 +564,7 @@ def _create_prod_workspace(cfg: dict, cloud: str, metastore_id: str, dev_state: client_secret=cfg.get("AZURE_CLIENT_SECRET", ""), ) arm_token = arm_cred.get_token("https://management.azure.com/.default").token - arm_api_version = "2025-10-01-preview" + arm_api_version = "2024-05-01" arm_url = ( f"https://management.azure.com/subscriptions/{subscription_id}" @@ -596,8 +596,14 @@ def _create_prod_workspace(cfg: dict, cloud: str, metastore_id: str, dev_state: while _time.time() < deadline: _time.sleep(15) get_req = urllib.request.Request(arm_url, headers={"Authorization": f"Bearer {arm_token}"}) - with urllib.request.urlopen(get_req) as resp: - data = json.loads(resp.read()) + try: + with urllib.request.urlopen(get_req) as resp: + data = json.loads(resp.read()) + except urllib.error.HTTPError as e: + if e.code >= 500: + print(f" Transient error ({e.code}), retrying...") + continue + raise props = data.get("properties", {}) prov_state = props.get("provisioningState", "Unknown") elapsed = 
int(_time.time() - (deadline - 600)) diff --git a/shared/generate_abac.py b/shared/generate_abac.py index fc26dfd..324e564 100644 --- a/shared/generate_abac.py +++ b/shared/generate_abac.py @@ -1513,6 +1513,20 @@ def call_with_retries(call_fn, prompt: str, model: str, max_retries: int) -> str raise RuntimeError(f"All {max_retries} attempts failed. Last error: {last_error}") +def _cleanup_stray_commas(text: str) -> str: + """Remove stray commas left behind by block removals in HCL text. + + Handles bare comma lines, consecutive commas, and trailing commas before ] + (but preserves valid trailing commas after ``}`` or quoted strings). + """ + text = re.sub(r'^\s*,\s*$', '', text, flags=re.MULTILINE) + text = re.sub(r',(\s*,)+', ',', text) + # Only strip a comma before ] when preceded by whitespace (bare/stray comma), + # not when preceded by } or " (valid HCL trailing comma). + text = re.sub(r'(? int: """Repair common HCL syntax errors introduced by the LLM. @@ -1609,13 +1623,8 @@ def _object_vals_to_strings(m: re.Match) -> str: # ------------------------------------------------------------------ # Fix 5: remove stray commas left by autofix block removals. 
- # - Bare comma lines - # - Consecutive commas - # - Trailing comma before ] # ------------------------------------------------------------------ - fixed5 = re.sub(r'^\s*,\s*$', '', text, flags=re.MULTILINE) - fixed5 = re.sub(r',(\s*,)+', ',', fixed5) - fixed5 = re.sub(r',(\s*\])', r'\1', fixed5) + fixed5 = _cleanup_stray_commas(text) if fixed5 != text: repairs += 1 text = fixed5 @@ -2553,9 +2562,7 @@ def _remove_block(txt: str, block_name: str) -> tuple[str, bool]: if removed or assignments_removed: # Clean up stray commas and double-blank lines left by removal - text = re.sub(r'^\s*,\s*$', '', text, flags=re.MULTILINE) - text = re.sub(r',(\s*,)+', ',', text) - text = re.sub(r',(\s*\])', r'\1', text) + text = _cleanup_stray_commas(text) text = re.sub(r"\n{3,}", "\n\n", text) tfvars_path.write_text(text) @@ -3872,9 +3879,7 @@ def autofix_invalid_function_refs(tfvars_path: Path, sql_path: Path | None = Non return 0 # Clean up stray commas left behind by block removals - rewritten = re.sub(r'^\s*,\s*$', '', rewritten, flags=re.MULTILINE) # bare comma lines - rewritten = re.sub(r',(\s*,)+', ',', rewritten) # consecutive commas - rewritten = re.sub(r',(\s*\])', r'\1', rewritten) # trailing comma before ] + rewritten = _cleanup_stray_commas(rewritten) text = text[:sec_start] + rewritten + text[sec_end:] @@ -4482,9 +4487,7 @@ def autofix_duplicate_column_masks(tfvars_path: Path) -> int: print(f" Removed duplicate mask policy '{name}' (generic function on column already covered by specific policy)") if removed: - # Clean up any leftover double commas or trailing commas before ] - text = re.sub(r',\s*,', ',', text) - text = re.sub(r',\s*\]', '\n ]', text) + text = _cleanup_stray_commas(text) tfvars_path.write_text(text) return removed @@ -4537,8 +4540,7 @@ def autofix_forbidden_conditions(tfvars_path: Path) -> int: removed += 1 if removed: - text = re.sub(r',\s*,', ',', text) - text = re.sub(r',\s*\]', '\n ]', text) + text = _cleanup_stray_commas(text) 
tfvars_path.write_text(text) return removed @@ -4622,8 +4624,7 @@ def autofix_invalid_condition_values(tfvars_path: Path) -> int: ) text = pattern.sub("", text, count=1) - text = re.sub(r',\s*,', ',', text) - text = re.sub(r',\s*\]', '\n ]', text) + text = _cleanup_stray_commas(text) tfvars_path.write_text(text) return len(bad_names) @@ -4697,8 +4698,7 @@ def autofix_malformed_conditions(tfvars_path: Path) -> int: ) text = pattern.sub("", text, count=1) - text = re.sub(r',\s*,', ',', text) - text = re.sub(r',\s*\]', '\n ]', text) + text = _cleanup_stray_commas(text) tfvars_path.write_text(text) return len(bad_names) diff --git a/shared/industries/financial_services.yaml b/shared/industries/financial_services.yaml index 1244dab..b04d386 100644 --- a/shared/industries/financial_services.yaml +++ b/shared/industries/financial_services.yaml @@ -144,45 +144,10 @@ masking_functions: ELSE '[REDACTED]' END - - name: mask_redact - signature: "mask_redact(val STRING) RETURNS STRING" - comment: "Full redaction — returns [REDACTED]" - body: "CASE WHEN val IS NULL THEN NULL ELSE '[REDACTED]' END" - - - name: mask_email - signature: "mask_email(email STRING) RETURNS STRING" - comment: "Email — mask local part, keep domain" - body: | - CASE - WHEN email IS NULL THEN NULL - WHEN INSTR(email, '@') > 1 THEN CONCAT(LEFT(email, 1), '****@', SUBSTRING(email, INSTR(email, '@') + 1)) - ELSE '[REDACTED]' - END - - - name: mask_phone - signature: "mask_phone(phone STRING) RETURNS STRING" - comment: "Phone number — show last 4 digits only" - body: | - CASE - WHEN phone IS NULL THEN NULL - WHEN LENGTH(REGEXP_REPLACE(phone, '[^0-9]', '')) >= 4 THEN CONCAT('***-***-', RIGHT(REGEXP_REPLACE(phone, '[^0-9]', ''), 4)) - ELSE '[REDACTED]' - END - - - name: mask_date_to_year - signature: "mask_date_to_year(dt DATE) RETURNS DATE" - comment: "Date — truncate to year" - body: "CASE WHEN dt IS NULL THEN NULL ELSE DATE_TRUNC('YEAR', dt) END" - - - name: mask_credit_card_full - signature: 
"mask_credit_card_full(card STRING) RETURNS STRING" - comment: "Full card redaction — PCI DSS" - body: "CASE WHEN card IS NULL THEN NULL ELSE '[REDACTED]' END" - - name: filter_aml_compliance signature: "filter_aml_compliance() RETURNS BOOLEAN" comment: "AML row filter — only compliance and fraud teams see all rows" - body: "is_account_group_member('fraud_team') OR is_account_group_member('Compliance_Officer') OR is_account_group_member('compliance_officer')" + body: "is_account_group_member('fraud_team') OR is_account_group_member('compliance_officer')" group_templates: fraud_team: diff --git a/shared/scripts/cloud_providers/azure_provider.py b/shared/scripts/cloud_providers/azure_provider.py index 67b1de6..c00beb2 100644 --- a/shared/scripts/cloud_providers/azure_provider.py +++ b/shared/scripts/cloud_providers/azure_provider.py @@ -244,7 +244,7 @@ def workspace_create_kwargs(self, region: str) -> dict: return {"location": region} # ARM preview API version that supports computeMode=Serverless - _ARM_API_VERSION = "2025-10-01-preview" + _ARM_API_VERSION = "2024-05-01" def _azure_credential(self, cfg: dict[str, str]): """Return an Azure credential from config (SP preferred, else DefaultAzureCredential).""" @@ -310,8 +310,14 @@ def create_workspace(self, cfg: dict[str, str], ws_name: str, region: str, accou while time.time() < deadline: time.sleep(poll_interval) get_req = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"}) - with urllib.request.urlopen(get_req) as resp: - data = json.loads(resp.read()) + try: + with urllib.request.urlopen(get_req) as resp: + data = json.loads(resp.read()) + except urllib.error.HTTPError as e: + if e.code >= 500: + print(f" Transient error ({e.code}), retrying...") + continue + raise props = data.get("properties", {}) state = props.get("provisioningState", "Unknown") elapsed = int(time.time() - (deadline - 600)) From 1c228f0d633a35f84c93397f927a0895c0e0bb68 Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: 
Mon, 20 Apr 2026 15:14:08 +0530 Subject: [PATCH 3/6] fix: bump ARM API version to 2025-10-01-preview for serverless workspaces Switch from 2024-05-01 (GA) to 2025-10-01-preview, the latest published ARM API version for Microsoft.Databricks/workspaces. Both support computeMode=Serverless; preview includes the most recent schema updates. Applied in shared/scripts/cloud_providers/azure_provider.py and shared/examples/aus_bank_demo/setup_demo.py. Co-authored-by: Isaac --- shared/examples/aus_bank_demo/setup_demo.py | 2 +- shared/scripts/cloud_providers/azure_provider.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/shared/examples/aus_bank_demo/setup_demo.py b/shared/examples/aus_bank_demo/setup_demo.py index 8410e7b..20f7767 100644 --- a/shared/examples/aus_bank_demo/setup_demo.py +++ b/shared/examples/aus_bank_demo/setup_demo.py @@ -564,7 +564,7 @@ def _create_prod_workspace(cfg: dict, cloud: str, metastore_id: str, dev_state: client_secret=cfg.get("AZURE_CLIENT_SECRET", ""), ) arm_token = arm_cred.get_token("https://management.azure.com/.default").token - arm_api_version = "2024-05-01" + arm_api_version = "2025-10-01-preview" arm_url = ( f"https://management.azure.com/subscriptions/{subscription_id}" diff --git a/shared/scripts/cloud_providers/azure_provider.py b/shared/scripts/cloud_providers/azure_provider.py index c00beb2..e4959f3 100644 --- a/shared/scripts/cloud_providers/azure_provider.py +++ b/shared/scripts/cloud_providers/azure_provider.py @@ -244,7 +244,7 @@ def workspace_create_kwargs(self, region: str) -> dict: return {"location": region} # ARM preview API version that supports computeMode=Serverless - _ARM_API_VERSION = "2024-05-01" + _ARM_API_VERSION = "2025-10-01-preview" def _azure_credential(self, cfg: dict[str, str]): """Return an Azure credential from config (SP preferred, else DefaultAzureCredential).""" From 1db1aa23e50c7778aa4ce687611b2a2b3ef06ac0 Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: Mon, 20 Apr 2026 
18:07:46 +0530 Subject: [PATCH 4/6] feat: add CVV and AML Risk Flag identifiers to financial_services overlay CVV columns (cvv, cvc, security_code) were missing from the identifiers list, so the LLM had no guidance to tag or mask them. Similarly, AML risk flag columns (aml_risk_flag, aml_flag) had a row filter function but no identifier definition with column hints. Adds: - CVV identifier + mask_cvv_redact (full redaction per PCI-DSS) - AML Risk Flag identifier + mask_aml_flag (restricted to compliance/fraud) - Updated prompt_overlay with detection hints for both Co-authored-by: Isaac --- shared/industries/financial_services.yaml | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/shared/industries/financial_services.yaml b/shared/industries/financial_services.yaml index b04d386..54f6a67 100644 --- a/shared/industries/financial_services.yaml +++ b/shared/industries/financial_services.yaml @@ -80,6 +80,30 @@ identifiers: masking_function: mask_email category: customer_pii + - name: Card Verification Value (CVV) + column_hints: + - cvv + - cvc + - card_verification_value + - security_code + - cvv2 + format: "3-4 digits" + sensitivity: restricted + masking_function: mask_cvv_redact + category: payment_card + + - name: AML Risk Flag + column_hints: + - aml_risk_flag + - aml_flag + - risk_flag + - aml_status + - compliance_flag + format: "Categorical (CLEAR, REVIEW, HIGH_RISK, BLOCKED)" + sensitivity: restricted + masking_function: mask_aml_flag + category: compliance_risk + masking_functions: - name: mask_account_last4 signature: "mask_account_last4(acct STRING) RETURNS STRING" @@ -144,6 +168,21 @@ masking_functions: ELSE '[REDACTED]' END + - name: mask_cvv_redact + signature: "mask_cvv_redact(cvv STRING) RETURNS STRING" + comment: "CVV/CVC — full redaction, PCI DSS prohibits any display or storage" + body: "CASE WHEN cvv IS NULL THEN NULL ELSE '[REDACTED]' END" + + - name: mask_aml_flag + signature: "mask_aml_flag(flag STRING) RETURNS STRING" + 
comment: "AML risk flag — visible only to compliance/fraud teams, masked for others" + body: | + CASE + WHEN flag IS NULL THEN NULL + WHEN is_account_group_member('fraud_team') OR is_account_group_member('compliance_officer') THEN flag + ELSE '[RESTRICTED]' + END + - name: filter_aml_compliance signature: "filter_aml_compliance() RETURNS BOOLEAN" comment: "AML row filter — only compliance and fraud teams see all rows" @@ -194,11 +233,18 @@ prompt_overlay: | Use `mask_routing` — show last 4 digits. - Credit Card Number (PAN): 13-19 digits. Columns: `card_number`, `credit_card`, `pan`. Use `mask_card_last4` — PCI DSS compliant, last 4 digits only. + - Card Verification Value (CVV/CVC): 3-4 digits. Columns: `cvv`, `cvc`, `security_code`. + Use `mask_cvv_redact` — PCI DSS prohibits any display; always fully redact. **Transaction Data:** - Transaction Amount: Decimal currency. Columns: `transaction_amount`, `txn_amount`, `amount`. Use `mask_amount_round` — round to nearest thousand for non-privileged users. + **Compliance / Risk:** + - AML Risk Flag: Categorical. Columns: `aml_risk_flag`, `aml_flag`, `risk_flag`. + Use `mask_aml_flag` — restricted to fraud_team and compliance_officer groups. + Values like HIGH_RISK and BLOCKED must not be visible to general users. + **Customer PII:** - SSN: 9 digits. Columns: `ssn`, `social_security`, `tax_id`. Use `mask_ssn_last4` — show last 4 digits only. 
@@ -214,6 +260,8 @@ prompt_overlay: | - `mask_amount_round(amount DOUBLE) RETURNS STRING` — rounded to nearest thousand - `mask_ssn_last4(ssn STRING) RETURNS STRING` — last 4 digits visible - `mask_name(name STRING) RETURNS STRING` — first initial only + - `mask_cvv_redact(cvv STRING) RETURNS STRING` — full redaction (PCI DSS) + - `mask_aml_flag(flag STRING) RETURNS STRING` — visible to compliance/fraud only **Suggested Group Structure:** - `fraud_team` / `compliance_officer`: Full access to all data (AML/SOX compliance) @@ -224,6 +272,7 @@ prompt_overlay: | **Regulatory Context:** - PCI DSS: Card numbers (PAN) must never be displayed in full to non-privileged users. Only first 6 and last 4 digits may be shown; prefer showing only last 4. + CVV/CVC must NEVER be displayed or stored — always use full redaction. - SOX: Financial data access must be auditable. Consider audit logging. - GLBA: Customer financial information must be protected from unauthorized access. - BSA/AML: Transaction monitoring teams (fraud_team) need full access to detect From dc183f1d87d857b62127a296fe16dd2471d22ce3 Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: Mon, 20 Apr 2026 22:56:45 +0530 Subject: [PATCH 5/6] fix: add owner tag to Azure resources + grant admin group metastore privileges Azure policy at db_fe management group requires an 'owner' tag on all resources. Added the tag to storage accounts, access connectors, and workspaces. Also grant the test admin group CREATE_CATALOG and CREATE_EXTERNAL_LOCATION on the metastore so members can see and manage UC objects without transferring metastore ownership from the SP. 
Co-authored-by: Isaac --- shared/examples/aus_bank_demo/setup_demo.py | 3 ++- .../scripts/cloud_providers/azure_provider.py | 8 +++--- shared/scripts/provision_test_env.py | 25 +++++++++++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/shared/examples/aus_bank_demo/setup_demo.py b/shared/examples/aus_bank_demo/setup_demo.py index 20f7767..6c66777 100644 --- a/shared/examples/aus_bank_demo/setup_demo.py +++ b/shared/examples/aus_bank_demo/setup_demo.py @@ -572,11 +572,12 @@ def _create_prod_workspace(cfg: dict, cloud: str, metastore_id: str, dev_state: f"/providers/Microsoft.Databricks/workspaces/{prod_ws_name}" f"?api-version={arm_api_version}" ) + owner = cfg.get("AZURE_CLIENT_ID", cfg.get("DATABRICKS_CLIENT_ID", "unknown")) arm_body = json.dumps({ "location": region, "sku": {"name": "premium"}, "properties": {"computeMode": "Serverless"}, - "tags": {"ManagedBy": "setup_demo"}, + "tags": {"ManagedBy": "setup_demo", "owner": owner}, }).encode() req = urllib.request.Request(arm_url, data=arm_body, method="PUT", headers={ diff --git a/shared/scripts/cloud_providers/azure_provider.py b/shared/scripts/cloud_providers/azure_provider.py index e4959f3..edfd29d 100644 --- a/shared/scripts/cloud_providers/azure_provider.py +++ b/shared/scripts/cloud_providers/azure_provider.py @@ -74,6 +74,7 @@ def setup_storage(self, cfg: dict[str, str], run_id: str, region: str, account_i # Storage account name: must be globally unique, 3-24 lowercase alphanumeric sa_name = f"genietest{run_id}"[:24].lower() container_name = "genie-test" + owner = cfg.get("AZURE_CLIENT_ID", cfg.get("DATABRICKS_CLIENT_ID", "unknown")) _step(f"Creating Azure Storage Account: {sa_name}") storage_client = StorageManagementClient(credential, subscription_id) @@ -86,7 +87,7 @@ def setup_storage(self, cfg: dict[str, str], run_id: str, region: str, account_i kind=Kind.STORAGE_V2, location=region, is_hns_enabled=True, # ADLS Gen2 (hierarchical namespace) - tags={"ManagedBy": 
"provision_test_env", "RunId": run_id}, + tags={"ManagedBy": "provision_test_env", "RunId": run_id, "owner": owner}, ), ) poller.result() @@ -113,7 +114,7 @@ def setup_storage(self, cfg: dict[str, str], run_id: str, region: str, account_i AccessConnector( location=region, identity=ManagedServiceIdentity(type="SystemAssigned"), - tags={"ManagedBy": "provision_test_env", "RunId": run_id}, + tags={"ManagedBy": "provision_test_env", "RunId": run_id, "owner": owner}, ), ) ac = poller.result() @@ -276,6 +277,7 @@ def create_workspace(self, cfg: dict[str, str], ws_name: str, region: str, accou subscription_id = cfg["AZURE_SUBSCRIPTION_ID"] resource_group = cfg["AZURE_RESOURCE_GROUP"] token = self._arm_token(cfg) + owner = cfg.get("AZURE_CLIENT_ID", cfg.get("DATABRICKS_CLIENT_ID", "unknown")) url = ( f"https://management.azure.com/subscriptions/{subscription_id}" @@ -287,7 +289,7 @@ def create_workspace(self, cfg: dict[str, str], ws_name: str, region: str, accou "location": region, "sku": {"name": "premium"}, "properties": {"computeMode": "Serverless"}, - "tags": {"ManagedBy": "provision_test_env"}, + "tags": {"ManagedBy": "provision_test_env", "owner": owner}, }).encode() _step(f"Creating serverless Azure Databricks workspace: {ws_name}") diff --git a/shared/scripts/provision_test_env.py b/shared/scripts/provision_test_env.py index 8d60768..bad30d4 100644 --- a/shared/scripts/provision_test_env.py +++ b/shared/scripts/provision_test_env.py @@ -1143,6 +1143,21 @@ def cmd_provision(cfg: dict[str, str], dry_run: bool = False, force: bool = Fals except Exception as exc: _warn(f"Could not grant metastore privileges (SP may already have them as creator): {exc}") + # Also grant metastore privileges to the admin group so its members + # can see and manage catalogs, schemas, and tables. 
+ try: + w_admin.grants.update( + securable_type="metastore", + full_name=ms_id, + changes=[PermissionsChange( + principal=group_name, + add=[Privilege.CREATE_CATALOG, Privilege.CREATE_EXTERNAL_LOCATION], + )], + ) + _ok(f"Metastore privileges granted to admin group: {group_name}") + except Exception as exc: + _warn(f"Could not grant metastore privileges to admin group: {exc}") + # ------------------------------------------------------------------ # Step 5c: Enable Partner Powered AI and warm up the Genie API. # @@ -1246,11 +1261,11 @@ def cmd_provision(cfg: dict[str, str], dry_run: bool = False, force: bool = Fals print(" Catalog creation requires an External Location. Aborting.") sys.exit(1) - # Note: we intentionally do NOT transfer metastore ownership to the admin - # group here. The SP (account admin and metastore creator) retains its - # implicit metastore admin status, which grants CREATE CATALOG and all - # other UC privileges needed for the integration tests. Transferring - # ownership would strip those rights and cause catalog creation to fail. + # Note: we do NOT transfer metastore *ownership* to the admin group — + # that would strip the SP's implicit admin rights and break catalog + # creation. Instead, the admin group receives explicit CREATE_CATALOG + # and CREATE_EXTERNAL_LOCATION grants (Step 5a-2 above) so its members + # can see and manage UC objects without needing metastore ownership. # Workspace admin assignment already done in Step 5a above. From d1d6556c3a76d0ad0022352597a1b39f0a7bc723 Mon Sep 17 00:00:00 2001 From: Kavya Parashar Date: Mon, 20 Apr 2026 23:17:38 +0530 Subject: [PATCH 6/6] fix: grant CAN_USE on SQL warehouse to all ABAC groups Without this, groups assigned to the Genie Space get "You do not have permission to use the SQL Warehouse" because no databricks_permissions resource existed for the warehouse. 
Co-authored-by: Isaac
---
 shared/modules/workspace/main.tf | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/shared/modules/workspace/main.tf b/shared/modules/workspace/main.tf
index 4a06083..f6adc41 100644
--- a/shared/modules/workspace/main.tf
+++ b/shared/modules/workspace/main.tf
@@ -95,6 +95,28 @@ resource "databricks_sql_endpoint" "warehouse" {
   auto_stop_mins = 15
 }
 
+# ── Grant CAN_USE on the SQL warehouse to every ABAC group ──────────────────
+
+resource "databricks_permissions" "warehouse_usage" {
+  count = (var.genie_only || length(local.group_ids) == 0) ? 0 : 1 # skip when no groups: the dynamic block below would emit zero access_control blocks
+
+  provider        = databricks.workspace
+  sql_endpoint_id = local.shared_warehouse_id
+
+  dynamic "access_control" {
+    for_each = local.group_ids
+    content {
+      group_name       = access_control.key
+      permission_level = "CAN_USE"
+    }
+  }
+
+  depends_on = [
+    databricks_sql_endpoint.warehouse,
+    databricks_entitlements.group_entitlements,
+  ]
+}
+
 # ── Existing spaces: apply ACLs + config (when config is defined) ─────────────
 
 resource "null_resource" "genie_space_acls" {