22# Copyright (c) 2022 Anonos IP LLC.
33# See https://github.com/statice/anonymeter/blob/main/LICENSE.md for details.
44"""Privacy evaluator that measures the inference risk."""
5-
65from typing import Optional
76
87import numpy as np
98import numpy .typing as npt
109import pandas as pd
1110
12- from anonymeter .neighbors .mixed_types_kneighbors import MixedTypeKNeighbors
11+ from anonymeter .evaluators .inference_predictor import InferencePredictor
12+ from anonymeter .neighbors .mixed_types_kneighbors import KNNInferencePredictor
1313from anonymeter .stats .confidence import EvaluationResults , PrivacyRisk
1414
1515
1616def _run_attack (
17- target : pd .DataFrame ,
18- syn : pd .DataFrame ,
19- n_attacks : int ,
20- aux_cols : list [str ],
21- secret : str ,
22- n_jobs : int ,
23- naive : bool ,
24- regression : Optional [bool ],
17+ target : pd .DataFrame ,
18+ syn : pd .DataFrame ,
19+ n_attacks : int ,
20+ aux_cols : list [str ],
21+ secret : str ,
22+ n_jobs : int ,
23+ naive : bool ,
24+ regression : Optional [bool ],
25+ inference_model : Optional [InferencePredictor ],
2526) -> int :
2627 if regression is None :
2728 regression = pd .api .types .is_numeric_dtype (target [secret ])
@@ -30,21 +31,17 @@ def _run_attack(
3031
3132 if naive :
3233 guesses = syn .sample (n_attacks )[secret ]
33-
3434 else :
35- nn = MixedTypeKNeighbors (n_jobs = n_jobs , n_neighbors = 1 ).fit (candidates = syn [aux_cols ])
36-
37- guesses_idx = nn .kneighbors (queries = targets [aux_cols ])
38- if isinstance (guesses_idx , tuple ):
39- raise RuntimeError ("guesses_idx cannot be a tuple" )
40-
41- guesses = syn .iloc [guesses_idx .flatten ()][secret ]
35+ # Instantiate the default KNN model if no other model is passed through `inference_model`.
36+ if inference_model is None :
37+ inference_model = KNNInferencePredictor (data = syn , columns = aux_cols , target_col = secret , n_jobs = n_jobs )
38+ guesses = inference_model .predict (targets )
4239
4340 return evaluate_inference_guesses (guesses = guesses , secrets = targets [secret ], regression = regression ).sum ()
4441
4542
4643def evaluate_inference_guesses (
47- guesses : pd .Series , secrets : pd .Series , regression : bool , tolerance : float = 0.05
44+ guesses : pd .Series , secrets : pd .Series , regression : bool , tolerance : float = 0.05
4845) -> npt .NDArray :
4946 """Evaluate the success of an inference attack.
5047
@@ -142,23 +139,33 @@ class InferenceEvaluator:
142139 the variable.
143140 n_attacks : int, default is 500
144141 Number of attack attempts.
142+ In case the whole dataset size should be used, set this to np.inf.
143+ inference_model: InferencePredictor
144+ An ml model fitted on `syn` as training data, and `secret` as target, that supports ::predict(x).
145+ If not None, it will be used over the MixedTypeKNeighbors in the attack.
145146
146147 """
147148
148149 def __init__ (
149- self ,
150- ori : pd .DataFrame ,
151- syn : pd .DataFrame ,
152- aux_cols : list [str ],
153- secret : str ,
154- regression : Optional [bool ] = None ,
155- n_attacks : int = 500 ,
156- control : Optional [pd .DataFrame ] = None ,
150+ self ,
151+ ori : pd .DataFrame ,
152+ syn : pd .DataFrame ,
153+ aux_cols : list [str ],
154+ secret : str ,
155+ regression : Optional [bool ] = None ,
156+ n_attacks : int = 500 ,
157+ control : Optional [pd .DataFrame ] = None ,
158+ inference_model : Optional [InferencePredictor ] = None
157159 ):
158160 self ._ori = ori
159161 self ._syn = syn
160162 self ._control = control
161163 self ._n_attacks = n_attacks
164+ self ._inference_model = inference_model
165+
166+ self ._n_attacks_ori = min (n_attacks , self ._ori .shape [0 ])
167+ self ._n_attacks_baseline = min (self ._syn .shape [0 ], self ._n_attacks_ori )
168+ self ._n_attacks_control = - 1 if self ._control is None else min (n_attacks , self ._control .shape [0 ])
162169
163170 # check if secret is a string column
164171 if not isinstance (secret , str ):
@@ -173,16 +180,17 @@ def __init__(
173180 self ._aux_cols = aux_cols
174181 self ._evaluated = False
175182
176- def _attack (self , target : pd .DataFrame , naive : bool , n_jobs : int ) -> int :
183+ def _attack (self , target : pd .DataFrame , naive : bool , n_jobs : int , n_attacks : int ) -> int :
177184 return _run_attack (
178185 target = target ,
179186 syn = self ._syn ,
180- n_attacks = self . _n_attacks ,
187+ n_attacks = n_attacks ,
181188 aux_cols = self ._aux_cols ,
182189 secret = self ._secret ,
183190 n_jobs = n_jobs ,
184191 naive = naive ,
185192 regression = self ._regression ,
193+ inference_model = self ._inference_model ,
186194 )
187195
188196 def evaluate (self , n_jobs : int = - 2 ) -> "InferenceEvaluator" :
@@ -199,11 +207,14 @@ def evaluate(self, n_jobs: int = -2) -> "InferenceEvaluator":
199207 The evaluated ``InferenceEvaluator`` object.
200208
201209 """
202- self ._n_baseline = self ._attack (target = self ._ori , naive = True , n_jobs = n_jobs )
203- self ._n_success = self ._attack (target = self ._ori , naive = False , n_jobs = n_jobs )
210+ self ._n_baseline = self ._attack (target = self ._ori , naive = True , n_jobs = n_jobs ,
211+ n_attacks = self ._n_attacks_baseline )
212+ self ._n_success = self ._attack (target = self ._ori , naive = False , n_jobs = n_jobs ,
213+ n_attacks = self ._n_attacks_ori )
204214 self ._n_control = (
205- None if self ._control is None else self ._attack (target = self ._control , naive = False , n_jobs = n_jobs )
206- )
215+ None if self ._control is None else self ._attack (target = self ._control , naive = False , n_jobs = n_jobs ,
216+ n_attacks = self ._n_attacks_control )
217+ )
207218
208219 self ._evaluated = True
209220 return self
@@ -226,7 +237,7 @@ def results(self, confidence_level: float = 0.95) -> EvaluationResults:
226237 raise RuntimeError ("The inference evaluator wasn't evaluated yet. Please, run `evaluate()` first." )
227238
228239 return EvaluationResults (
229- n_attacks = self ._n_attacks ,
240+ n_attacks = ( self ._n_attacks_ori , self . _n_attacks_baseline , self . _n_attacks_control ) ,
230241 n_success = self ._n_success ,
231242 n_baseline = self ._n_baseline ,
232243 n_control = self ._n_control ,
0 commit comments