From d434e5cd7363f59c757e10142cd3b1bb549a8f65 Mon Sep 17 00:00:00 2001 From: "Hyeong Kyun (Daniel) Park" Date: Wed, 4 Mar 2026 06:42:01 -0500 Subject: [PATCH 1/3] Refactor dataset paths and adjust training parameters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - assertion errors when true_concepts don’t match tmp_concepts (due to dataloader shuffling) - batch sizes > 1 cause shape mismatches --- rsseval/rss/utils/tcav/tcav/main.py | 139 +++++++++++++++++++++++----- 1 file changed, 114 insertions(+), 25 deletions(-) diff --git a/rsseval/rss/utils/tcav/tcav/main.py b/rsseval/rss/utils/tcav/tcav/main.py index 1b35c64b..a8abaf8c 100644 --- a/rsseval/rss/utils/tcav/tcav/main.py +++ b/rsseval/rss/utils/tcav/tcav/main.py @@ -9,6 +9,12 @@ from collections import OrderedDict from torch.utils.data import DataLoader +import sys +project_root = "/home/park1119/rsbench-code/rsseval/rss" + +if project_root not in sys.path: + sys.path.insert(0, project_root) + torch.multiprocessing.set_sharing_strategy("file_system") from datasets.boia import BOIA @@ -70,7 +76,7 @@ def validate( is_boia = True if dataset_name in ["shortmnist", "clipshortmnist"]: - extract_layer = "conv2" # conv1, conv2, fc1, fc2 + extract_layer = "fc2" # conv1, conv2, fc1, fc2 is_boia = False if dataset_name in ["boia", "clipboia"]: extract_layer = "fc1" # fc1, fc2, fc3, fc4 @@ -153,27 +159,97 @@ def get_dataset(datasetname, args): def setup(): + # WARNING: BATCH SIZE MUST BE 1 FOR TCAV EVALUATION + # Set validate=0 because validating? 
+ + # CHANGE HERE + + # args = Namespace( + # backbone="neural", # + # preprocess=0, + # finetuning=0, + # batch_size=128, + # n_epochs=40, + # validate=0, # validate=0 means evaluate on test, not validation (train.py) + # dataset="shortmnist", + # lr=0.001, + # exp_decay=0.99, + # warmup_steps=0, # this is default even when it was first trained + # wandb=None, + # task="addition", + # model="mnistnn", + # c_sup=0, + # which_c=[-1], + # joint=False # True only for CLIP + # ) + + # args = Namespace( + # backbone="neural", # + # preprocess=0, + # finetuning=0, + # batch_size=1, + # n_epochs=40, + # validate=0, + # dataset="clipshortmnist", + # lr=0.001, + # exp_decay=0.99, + # weight_decay=0.01, + # warmup_steps=0, + # wandb=None, + # task="addition", + # model="mnistnn", # + # c_sup=0, + # which_c=[-1], + # joint=True, + # ) + + + # EDIT HERE FOR DIFFERENT SETTINGS args = Namespace( - backbone="neural", # "conceptizer", + backbone="neural", # preprocess=0, finetuning=0, - batch_size=1, - n_epochs=20, - validate=1, - dataset="clipsddoia", + batch_size=32, + n_epochs=40, + validate=0, # validate=0 means evaluate on test, not validation (train.py) + dataset="sddoia", lr=0.001, exp_decay=0.99, - warmup_steps=1, + weight_decay=0.0001, + warmup_steps=0, # this is default even when it was first trained wandb=None, task="boia", - boia_model="ce", + boia_model="ce", # ce (cross-entropy) or bce (binary cross-entropy?) 
keep it at ce model="sddoiann", c_sup=0, - which_c=-1, - joint=True, - boia_ood_knowledge=False, + which_c=[-1], + joint=False, # True only for CLIP + boia_ood_knowledge=False, # case for nesy models to handle ambulance ) + # args = Namespace( + # backbone="neural", # "conceptizer", + # preprocess=0, + # finetuning=0, + # batch_size=1, + # n_epochs=20, + # validate=0, + # dataset="clipsddoia", + # lr=0.001, + # exp_decay=0.99, + # warmup_steps=1, + # wandb=None, + # task="boia", + # boia_model="ce", + # model="sddoiann", + # c_sup=0, + # which_c=-1, + # joint=True, + # boia_ood_knowledge=False, + # ) + + + # get dataset dataset = get_dataset(args.dataset, args) # get model @@ -232,12 +308,13 @@ def mnist_tcav_setup(): ] tmp_concept_dict = {} - for dirname in os.listdir("../data/concepts"): - fullpath = os.path.join("../data/concepts", dirname) + for dirname in os.listdir("../../../data/concepts/mnist"): + fullpath = os.path.join("../../../data/concepts/mnist", dirname) if os.path.isdir(fullpath): tmp_concept_dict[dirname] = data_loader(fullpath, args.dataset) concept_dict = OrderedDict() + print(tmp_concept_dict.keys()) for c in concepts_order: concept_dict[c] = tmp_concept_dict[c] @@ -263,9 +340,9 @@ def kand_tcav_setup(is_clip=False): tmp_concept_dict = {} - lmao_name = "../data/kand-tcav/" + lmao_name = "../../../data/concepts/kand-tcav/" if is_clip: - lmao_name = "../data/kand-tcav-clip/" + lmao_name = "../../../data/concepts/kand-tcav-clip/" for dirname in os.listdir(lmao_name): fullpath = os.path.join(lmao_name, dirname) @@ -320,8 +397,8 @@ def boia_tcav_setup(): ] tmp_concept_dict = {} - for dirname in os.listdir("../data/boia-preprocess-full/concepts/"): - fullpath = os.path.join("../data/boia-preprocess-full/concepts/", dirname) + for dirname in os.listdir("../../../data/concepts/boia-preprocess-full/"): + fullpath = os.path.join("../../../data/concepts/boia-preprocess-full/", dirname) if os.path.isdir(fullpath): tmp_concept_dict[dirname] = 
data_loader(fullpath, args.dataset) @@ -378,8 +455,8 @@ def sddoia_tcav_setup(full=False): folder_suffix = "-preprocess-full" tmp_concept_dict = {} - for dirname in os.listdir(f"../data/sddoia{folder_suffix}/concepts"): - fullpath = os.path.join(f"../data/sddoia{folder_suffix}/concepts", dirname) + for dirname in os.listdir(f"../../../data/concepts/sddoia{folder_suffix}"): + fullpath = os.path.join(f"../../../data/concepts/sddoia{folder_suffix}", dirname) if os.path.isdir(fullpath): tmp_concept_dict[dirname] = data_loader(fullpath, args.dataset) @@ -407,8 +484,8 @@ def xor_tcav_setup(): ] tmp_concept_dict = {} - for dirname in os.listdir("../data/xor/concepts"): - fullpath = os.path.join("../data/xor/concepts", dirname) + for dirname in os.listdir("../../../data/concepts/xor"): + fullpath = os.path.join("../../../data/concepts/xor", dirname) if os.path.isdir(fullpath): for i in range(4): @@ -524,8 +601,8 @@ def mnmath_tcav_setup(): "xxxxxxx9", ] tmp_concept_dict = {} - for dirname in os.listdir("../data/concepts"): - fullpath = os.path.join("../data/concepts", dirname) + for dirname in os.listdir("../../../data/concepts/mnmath"): + fullpath = os.path.join("../../../data/concepts/mnmath", dirname) if os.path.isdir(fullpath): for i in range(8): concept_name = '' @@ -566,7 +643,9 @@ def mnmath_tcav_setup(): # get everything args, dataset, model = setup() - seeds = [123, 456, 789, 1011, 1213] + # CHANGE HERE + seeds = [123, 456, 789, 1011, 1213, 1415, 1617, 1819, 2021, 2223] + #seeds = [1415, 1617, 1819, 2021, 2223] model_path = f"best_model_{args.dataset}_{args.model}" sddoia_full = "" to_add = "" # "_padd_random" @@ -577,7 +656,9 @@ def mnmath_tcav_setup(): print("Doing seed", seed) - current_model_path = f"{model_path}_{seed}.pth" + # CHANGE HERE + current_model_path = f"../../../jobs/{args.dataset}_nn_run_nocsup/{model_path}_{seed}.pth" + print(f"Loading model from {current_model_path}") if not os.path.exists(current_model_path): print(f"{current_model_path} is 
missing...") @@ -608,6 +689,14 @@ def mnmath_tcav_setup(): to_add = "_full" validloader, class_dict, concept_dict = sddoia_tcav_setup() + # added by daniel: reset validloader with batch_size=1 and shuffle=False + validloader = torch.utils.data.DataLoader( + validloader.dataset, + batch_size=1, + shuffle=False, # Disable shuffling for consistent ordering + num_workers=4, + ) + validate( model, args.dataset, From 232a85c7801cba1bef03b15a3dd06bcc012a6c17 Mon Sep 17 00:00:00 2001 From: "Hyeong Kyun (Daniel) Park" Date: Wed, 4 Mar 2026 06:43:18 -0500 Subject: [PATCH 2/3] Add script to generate MNIST concept folders The script generates concept folders for the SHORTMNIST dataset by extracting individual digits from 2-digit images and saving them in corresponding directories. --- .../tcav/tcav/generate_mnist_concepts.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 rsseval/rss/utils/tcav/tcav/generate_mnist_concepts.py diff --git a/rsseval/rss/utils/tcav/tcav/generate_mnist_concepts.py b/rsseval/rss/utils/tcav/tcav/generate_mnist_concepts.py new file mode 100644 index 00000000..a8b1ec3f --- /dev/null +++ b/rsseval/rss/utils/tcav/tcav/generate_mnist_concepts.py @@ -0,0 +1,130 @@ +import os +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from PIL import Image +import numpy as np + +import sys +project_root = "/home/park1119/rsbench-code/rsseval/rss" +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from datasets.shortcutmnist import SHORTMNIST +from argparse import Namespace + +def generate_mnist_concepts(output_base="data/concepts/mnist"): + """Generate concept folders for shortmnist using the strategy: + - For concept "1", save samples where first digit is 1 (labels like 1x) + - These will be padded differently during TCAV evaluation to isolate the concept + """ + + # Setup dataset + args = Namespace( + backbone="neural", # + preprocess=0, + finetuning=0, + batch_size=32, + 
n_epochs=40, + validate=0, # validate=0 means evaluate on test, not validation (train.py) + dataset="shortmnist", + lr=0.001, + exp_decay=0.99, + warmup_steps=0, # this is default even when it was first trained + wandb=None, + task="addition", + model="mnistnn", + c_sup=1, # Set to 1 to get actual digit labels instead of -1 + which_c=[-1], + joint=False # True only for CLIP + ) + + # args = Namespace( + # backbone="neural", + # preprocess=0, + # finetuning=0, + # batch_size=1, + # n_epochs=40, + # validate=0, + # dataset="shortmnist", + # lr=0.001, + # exp_decay=0.99, + # warmup_steps=0, + # wandb=None, + # model="mnistnn" + # ) + + dataset = SHORTMNIST(args) + train_loader, _, _ = dataset.get_data_loaders() + + # Load the RAW unfiltered data to get digit labels + raw_data_path = os.path.join(project_root, "datasets/utils/2mnist_10digits/2mnist_10digits.pt") + raw_data = torch.load(raw_data_path, weights_only=False) + raw_labels = raw_data['train']['labels'] # Shape: (42000, 3) with [digit1, digit2, sum] + print(f"Raw labels shape: {raw_labels.shape}") + + # Also need to load train indexes to reverse map from actual images to digit pairs + train_indexes = torch.load(os.path.join(project_root, "datasets/utils/2mnist_10digits/train_indexes.pt"), weights_only=False) + + # Create a mapping from image to its digit pair + # The indexes structure is: {(digit1, digit2): [list of indices in raw dataset]} + idx_to_digits = {} + for (d1, d2), indices in train_indexes.items(): + for idx in indices: + idx_to_digits[idx.item()] = (d1, d2) + + # Access the dataset directly + train_dataset = dataset.dataset_train + + # Create concept directories (one for each digit 0-9) + for digit in range(10): + concept_dir = os.path.join(project_root, output_base, str(digit)) + os.makedirs(concept_dir, exist_ok=True) + + # Process dataset and save images by concept (iterating directly, not via dataloader) + img_count = {i: 0 for i in range(10)} + + print(f"Processing {len(train_dataset)} 
images...") + print("Extracting individual digits from 2-digit images...") + + for idx in range(len(train_dataset)): + # Get image and concepts directly from dataset + img, label, concepts = train_dataset[idx] + + # Use concepts to get the actual digit pair for this image + first_digit = int(concepts[0]) + second_digit = int(concepts[1]) + + # img is shape (1, 28, 56) - grayscale, height=28, width=56 (two 28x28 digits side by side) + # Extract left half (first digit) and right half (second digit) + img_pil = transforms.ToPILImage()(img) # Convert to PIL + + # Convert to numpy for easier slicing + img_array = np.array(img_pil) # Shape: (28, 56) + + # Extract first digit (left half, columns 0-27) + first_digit_array = img_array[:, 0:28] + first_digit_pil = Image.fromarray(first_digit_array.astype(np.uint8)) + + # Extract second digit (right half, columns 28-55) + second_digit_array = img_array[:, 28:56] + second_digit_pil = Image.fromarray(second_digit_array.astype(np.uint8)) + + # Save first digit to its concept folder + concept_dir_first = os.path.join(project_root, output_base, str(first_digit)) + img_path_first = os.path.join(concept_dir_first, f"{img_count[first_digit]:06d}.png") + first_digit_pil.save(img_path_first) + img_count[first_digit] += 1 + + # Save second digit to its concept folder + concept_dir_second = os.path.join(project_root, output_base, str(second_digit)) + img_path_second = os.path.join(concept_dir_second, f"{img_count[second_digit]:06d}.png") + second_digit_pil.save(img_path_second) + img_count[second_digit] += 1 + + if idx % 5000 == 0: + print(f"Processed {idx} images, counts: {img_count}") + +if __name__ == "__main__": + generate_mnist_concepts() + print("Concept generation complete!") From ae7a2467c1cf68c596e8d4a08f981d27598f8b0e Mon Sep 17 00:00:00 2001 From: "Hyeong Kyun (Daniel) Park" Date: Wed, 4 Mar 2026 06:46:15 -0500 Subject: [PATCH 3/3] Update README with NN/CLIP evaluation instructions Added instructions for evaluating NN/CLIP 
models using TCAV and analysis notebooks. --- rsseval/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/rsseval/README.md b/rsseval/README.md index 43c82726..a2e5c355 100644 --- a/rsseval/README.md +++ b/rsseval/README.md @@ -133,6 +133,12 @@ To evaluate different models or datasets, follow this pattern: To evaluate your model, start by training several instances with different seed values. This will ensure a robust evaluation by averaging results across various seeds. We provide an easy-to-use notebook in the `notebooks` directory for this purpose. You can find the evaluation notebook [here](rss/notebooks/evaluate.ipynb). Simply follow the instructions within the notebook to assess your model's performance. +For NN/CLIP models, extract the concept-level metrics with the following steps instead of using `evaluate.ipynb`: +- Train the model. +- Run the TCAV script (`rss/utils/tcav/tcav/main.py`). +- Run `analysis.ipynb`. +- Extract the concept accuracy, F1, and collapse metrics. + ## Hyperparameter Tuning Our repository also supports hyperparameter tuning using a Bayesian search strategy. To begin tuning, use the `--tuning` flag: @@ -218,4 +224,4 @@ cd docs; make html ## Libraries and extra tools -This code is adapted from [Marconato et al. (2024) bears](https://github.com/samuelebortolotti/bears). \ No newline at end of file +This code is adapted from [Marconato et al. (2024) bears](https://github.com/samuelebortolotti/bears).