-
Notifications
You must be signed in to change notification settings - Fork 1
Ensemble fix: decouple ID extraction from drop operation #124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
43b97f1
a67a0cf
0ed74e9
007e33b
927eb47
35fc199
697a592
d2351cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from midst_toolkit.attacks.ensemble.data_utils import load_dataframe | ||
| from midst_toolkit.common.logger import log | ||
| from midst_toolkit.common.random import set_all_random_seeds | ||
| from midst_toolkit.models.clavaddpm.train import get_df_without_id | ||
|
|
||
|
|
||
| class RmiaTrainingDataChoice(Enum): | ||
|
|
@@ -51,37 +52,34 @@ def save_results( | |
| f.write(f"TPR at FPR=0.1: {pred_score:.4f}\n") | ||
|
|
||
|
|
||
| def extract_and_drop_id_column( | ||
| def extract_the_main_id_column( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Super minor, but maybe "primary_id" instead of "main_id"? Primary keys are the unique keys in databases. So it might be more suggestive to a reader 🙂
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I knew there was a better name, but I couldn't remember at the time lol. Thanks!
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suuuper nit-picky, but can we rename the function to extract_main_id_column (drop the "the")? Sounds better
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point! |
||
| data_frame: pd.DataFrame, | ||
| data_types_file_path: Path, | ||
| ) -> tuple[pd.DataFrame, pd.Series]: | ||
| ) -> pd.Series: | ||
| """ | ||
| Extracts IDs from the dataframe and drops the ID column. ID column is identified based on | ||
| Extracts and returns the main IDs from the dataframe. The main ID column is identified based on | ||
| the data types JSON file with "id_column_name" key. | ||
| Main IDs are not repeated in the dataset. | ||
| For example, in the Berka dataset, "trans_id" is the main ID column, and "account_id" is not the main ID column. | ||
|
|
||
| Args: | ||
| data_frame: Input dataframe. | ||
| data_types_file_path: Path to the data types JSON file. | ||
|
|
||
| Returns: | ||
| A tuple containing: | ||
| - The modified dataframe with ID columns dropped. | ||
| - A Series containing the extracted data of ID columns. | ||
| A Series containing the extracted data of the main ID column. | ||
| """ | ||
| # Extract ID column from the dataframe | ||
| with open(data_types_file_path, "r") as f: | ||
| column_types = json.load(f) | ||
|
|
||
| assert "id_column_name" in column_types, f"{data_types_file_path} must contain 'id_column_name' key." | ||
| id_column_name = column_types["id_column_name"] | ||
|
|
||
| # Make sure we have one main id column | ||
| assert isinstance(id_column_name, str), "Only one main id column should be identified." | ||
| assert id_column_name in data_frame.columns, f"Dataframe must have {id_column_name} column" | ||
| data_trans_ids = data_frame[id_column_name] | ||
|
|
||
| # Drop ID column from data | ||
| data_frame = data_frame.drop(columns=id_column_name) | ||
|
|
||
| return data_frame, data_trans_ids | ||
| return data_frame[id_column_name] | ||
|
|
||
|
|
||
| def run_rmia_shadow_training(config: DictConfig, df_challenge: pd.DataFrame) -> list[dict[str, list[Any]]]: | ||
|
|
@@ -183,7 +181,7 @@ def collect_challenge_and_train_data( | |
| midst_data_input_dir=targets_data_path, | ||
| attack_types=challenge_attack_types, | ||
| # For ensemble experiments, change to ``test`` for 10k, and change to ``final`` for 20k | ||
| split_folders=["test"], | ||
| split_folders=["final"], | ||
| dataset="challenge", | ||
| data_processing_config=data_processing_config, | ||
| ) | ||
|
|
@@ -266,7 +264,7 @@ def train_rmia_shadows_for_test_phase(config: DictConfig) -> list[dict[str, list | |
| df_challenge_experiment, df_master_train = collect_challenge_and_train_data( | ||
| config.data_processing_config, | ||
| processed_attack_data_path=Path(config.data_paths.processed_attack_data_path), | ||
| targets_data_path=Path(config.data_paths.midst_data_path), | ||
| targets_data_path=Path(config.data_processing_config.midst_data_path), | ||
| ) | ||
| # Load the challenge dataframe for training RMIA shadow models. | ||
| rmia_training_choice = RmiaTrainingDataChoice(config.target_model.attack_rmia_shadow_training_data_choice) | ||
|
|
@@ -357,8 +355,13 @@ def run_metaclassifier_testing( | |
| else: | ||
| log(INFO, "All shadow models for testing phase found. Using existing RMIA shadow models...") | ||
|
|
||
| # Extract and drop id columns from the test data | ||
| test_data, test_trans_ids = extract_and_drop_id_column(test_data, Path(config.metaclassifier.data_types_file_path)) | ||
| # Extract the main ID column's values from the test data | ||
| test_trans_ids = extract_the_main_id_column( | ||
| data_frame=test_data, | ||
| data_types_file_path=Path(config.metaclassifier.data_types_file_path), | ||
| ) | ||
| # Drop id columns from the test data. Berka has two id columns: "trans_id" and "account_id". | ||
| test_data = get_df_without_id(test_data) | ||
|
|
||
| # 4) Initialize the attacker object, and assign the loaded metaclassifier to it. | ||
| blending_attacker = BlendingPlusPlus( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
One step towards config modularization.