Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion rsseval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ To evaluate different models or datasets, follow this pattern:

To evaluate your model, start by training several instances with different seed values. This will ensure a robust evaluation by averaging results across various seeds. We provide an easy-to-use notebook in the `notebooks` directory for this purpose. You can find the evaluation notebook [here](rss/notebooks/evaluate.ipynb). Simply follow the instructions within the notebook to assess your model's performance.

For NN/CLIP models, concept-level metrics cannot be obtained from `evaluate.ipynb`; extract them as follows instead:
- Train the model
- Run the TCAV `main.py`
- Run `analysis.ipynb`
- Extract the concept accuracy, F1, and collapse metrics

## Hyperparameter Tuning

Our repository also supports hyperparameter tuning using a Bayesian search strategy. To begin tuning, use the `--tuning` flag:
Expand Down Expand Up @@ -218,4 +224,4 @@ cd docs; make html

## Libraries and extra tools

This code is adapted from [Marconato et al. (2024) bears](https://github.com/samuelebortolotti/bears).
This code is adapted from [Marconato et al. (2024) bears](https://github.com/samuelebortolotti/bears).
130 changes: 130 additions & 0 deletions rsseval/rss/utils/tcav/tcav/generate_mnist_concepts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

import sys
project_root = "/home/park1119/rsbench-code/rsseval/rss"
if project_root not in sys.path:
sys.path.insert(0, project_root)

from datasets.shortcutmnist import SHORTMNIST
from argparse import Namespace

def generate_mnist_concepts(output_base="data/concepts/mnist"):
    """Generate per-digit concept image folders from the shortmnist training set.

    Each shortmnist sample is a (1, 28, 56) grayscale image holding two
    28x28 digits side by side.  Every image is split into its left and
    right halves, and each half is saved as a PNG under
    ``<project_root>/<output_base>/<digit>/`` according to the digit pair
    reported in the sample's concept labels.  The resulting folders are
    consumed as concept sets by the TCAV evaluation pipeline.

    Args:
        output_base: directory, relative to ``project_root``, that will
            receive one sub-folder per digit 0-9.
    """

    # Dataset arguments mirroring how the model was originally trained.
    # c_sup=1 is required so samples carry real digit labels in their
    # concept vector instead of the unsupervised placeholder -1.
    args = Namespace(
        backbone="neural",
        preprocess=0,
        finetuning=0,
        batch_size=32,
        n_epochs=40,
        validate=0,  # validate=0 means evaluate on test, not validation (train.py)
        dataset="shortmnist",
        lr=0.001,
        exp_decay=0.99,
        warmup_steps=0,  # this is default even when it was first trained
        wandb=None,
        task="addition",
        model="mnistnn",
        c_sup=1,  # Set to 1 to get actual digit labels instead of -1
        which_c=[-1],
        joint=False,  # True only for CLIP
    )

    dataset = SHORTMNIST(args)
    # get_data_loaders() materialises dataset.dataset_train as a side
    # effect; the loaders themselves are not needed because we index the
    # training set directly below.
    dataset.get_data_loaders()
    train_dataset = dataset.dataset_train

    # Create one concept directory per digit 0-9.
    for digit in range(10):
        concept_dir = os.path.join(project_root, output_base, str(digit))
        os.makedirs(concept_dir, exist_ok=True)

    # Running count of images saved per digit, used to name output files.
    img_count = {i: 0 for i in range(10)}

    print(f"Processing {len(train_dataset)} images...")
    print("Extracting individual digits from 2-digit images...")

    for idx in range(len(train_dataset)):
        # Each sample yields (image, label, concepts); the addition label
        # itself is not needed here, only the per-digit concept vector.
        img, _, concepts = train_dataset[idx]

        # Concepts hold the ground-truth digit pair for this image.
        first_digit = int(concepts[0])
        second_digit = int(concepts[1])

        # img is (1, 28, 56): two 28x28 digits side by side.  Convert via
        # PIL so float tensors in [0, 1] are rescaled to uint8 exactly as
        # torchvision defines it, then slice the halves as a numpy array.
        img_array = np.array(transforms.ToPILImage()(img))  # shape (28, 56)

        # Left half (columns 0-27) is the first digit.
        first_digit_pil = Image.fromarray(img_array[:, 0:28].astype(np.uint8))
        # Right half (columns 28-55) is the second digit.
        second_digit_pil = Image.fromarray(img_array[:, 28:56].astype(np.uint8))

        # Save each half into its digit's concept folder, numbering files
        # sequentially per digit.
        concept_dir_first = os.path.join(project_root, output_base, str(first_digit))
        img_path_first = os.path.join(concept_dir_first, f"{img_count[first_digit]:06d}.png")
        first_digit_pil.save(img_path_first)
        img_count[first_digit] += 1

        concept_dir_second = os.path.join(project_root, output_base, str(second_digit))
        img_path_second = os.path.join(concept_dir_second, f"{img_count[second_digit]:06d}.png")
        second_digit_pil.save(img_path_second)
        img_count[second_digit] += 1

        if idx % 5000 == 0:
            print(f"Processed {idx} images, counts: {img_count}")

if __name__ == "__main__":
generate_mnist_concepts()
print("Concept generation complete!")
139 changes: 114 additions & 25 deletions rsseval/rss/utils/tcav/tcav/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from collections import OrderedDict
from torch.utils.data import DataLoader

import sys
project_root = "/home/park1119/rsbench-code/rsseval/rss"

if project_root not in sys.path:
sys.path.insert(0, project_root)

torch.multiprocessing.set_sharing_strategy("file_system")

from datasets.boia import BOIA
Expand Down Expand Up @@ -70,7 +76,7 @@ def validate(
is_boia = True

if dataset_name in ["shortmnist", "clipshortmnist"]:
extract_layer = "conv2" # conv1, conv2, fc1, fc2
extract_layer = "fc2" # conv1, conv2, fc1, fc2
is_boia = False
if dataset_name in ["boia", "clipboia"]:
extract_layer = "fc1" # fc1, fc2, fc3, fc4
Expand Down Expand Up @@ -153,27 +159,97 @@ def get_dataset(datasetname, args):


def setup():
# WARNING: BATCH SIZE MUST BE 1 FOR TCAV EVALUATION
# Set validate=0 because validating?

# CHANGE HERE

# args = Namespace(
# backbone="neural", #
# preprocess=0,
# finetuning=0,
# batch_size=128,
# n_epochs=40,
# validate=0, # validate=0 means evaluate on test, not validation (train.py)
# dataset="shortmnist",
# lr=0.001,
# exp_decay=0.99,
# warmup_steps=0, # this is default even when it was first trained
# wandb=None,
# task="addition",
# model="mnistnn",
# c_sup=0,
# which_c=[-1],
# joint=False # True only for CLIP
# )

# args = Namespace(
# backbone="neural", #
# preprocess=0,
# finetuning=0,
# batch_size=1,
# n_epochs=40,
# validate=0,
# dataset="clipshortmnist",
# lr=0.001,
# exp_decay=0.99,
# weight_decay=0.01,
# warmup_steps=0,
# wandb=None,
# task="addition",
# model="mnistnn", #
# c_sup=0,
# which_c=[-1],
# joint=True,
# )


# EDIT HERE FOR DIFFERENT SETTINGS
args = Namespace(
backbone="neural", # "conceptizer",
backbone="neural", #
preprocess=0,
finetuning=0,
batch_size=1,
n_epochs=20,
validate=1,
dataset="clipsddoia",
batch_size=32,
n_epochs=40,
validate=0, # validate=0 means evaluate on test, not validation (train.py)
dataset="sddoia",
lr=0.001,
exp_decay=0.99,
warmup_steps=1,
weight_decay=0.0001,
warmup_steps=0, # this is default even when it was first trained
wandb=None,
task="boia",
boia_model="ce",
boia_model="ce", # ce (cross-entropy) or bce (binary cross-entropy?) keep it at ce
model="sddoiann",
c_sup=0,
which_c=-1,
joint=True,
boia_ood_knowledge=False,
which_c=[-1],
joint=False, # True only for CLIP
boia_ood_knowledge=False, # case for nesy models to handle ambulance
)

# args = Namespace(
# backbone="neural", # "conceptizer",
# preprocess=0,
# finetuning=0,
# batch_size=1,
# n_epochs=20,
# validate=0,
# dataset="clipsddoia",
# lr=0.001,
# exp_decay=0.99,
# warmup_steps=1,
# wandb=None,
# task="boia",
# boia_model="ce",
# model="sddoiann",
# c_sup=0,
# which_c=-1,
# joint=True,
# boia_ood_knowledge=False,
# )



# get dataset
dataset = get_dataset(args.dataset, args)
# get model
Expand Down Expand Up @@ -232,12 +308,13 @@ def mnist_tcav_setup():
]

tmp_concept_dict = {}
for dirname in os.listdir("../data/concepts"):
fullpath = os.path.join("../data/concepts", dirname)
for dirname in os.listdir("../../../data/concepts/mnist"):
fullpath = os.path.join("../../../data/concepts/mnist", dirname)
if os.path.isdir(fullpath):
tmp_concept_dict[dirname] = data_loader(fullpath, args.dataset)

concept_dict = OrderedDict()
print(tmp_concept_dict.keys())
for c in concepts_order:
concept_dict[c] = tmp_concept_dict[c]

Expand All @@ -263,9 +340,9 @@ def kand_tcav_setup(is_clip=False):

tmp_concept_dict = {}

lmao_name = "../data/kand-tcav/"
lmao_name = "../../../data/concepts/kand-tcav/"
if is_clip:
lmao_name = "../data/kand-tcav-clip/"
lmao_name = "../../../data/concepts/kand-tcav-clip/"

for dirname in os.listdir(lmao_name):
fullpath = os.path.join(lmao_name, dirname)
Expand Down Expand Up @@ -320,8 +397,8 @@ def boia_tcav_setup():
]

tmp_concept_dict = {}
for dirname in os.listdir("../data/boia-preprocess-full/concepts/"):
fullpath = os.path.join("../data/boia-preprocess-full/concepts/", dirname)
for dirname in os.listdir("../../../data/concepts/boia-preprocess-full/"):
fullpath = os.path.join("../../../data/concepts/boia-preprocess-full/", dirname)
if os.path.isdir(fullpath):
tmp_concept_dict[dirname] = data_loader(fullpath, args.dataset)

Expand Down Expand Up @@ -378,8 +455,8 @@ def sddoia_tcav_setup(full=False):
folder_suffix = "-preprocess-full"

tmp_concept_dict = {}
for dirname in os.listdir(f"../data/sddoia{folder_suffix}/concepts"):
fullpath = os.path.join(f"../data/sddoia{folder_suffix}/concepts", dirname)
for dirname in os.listdir(f"../../../data/concepts/sddoia{folder_suffix}"):
fullpath = os.path.join(f"../../../data/concepts/sddoia{folder_suffix}", dirname)
if os.path.isdir(fullpath):
tmp_concept_dict[dirname] = data_loader(fullpath, args.dataset)

Expand Down Expand Up @@ -407,8 +484,8 @@ def xor_tcav_setup():
]

tmp_concept_dict = {}
for dirname in os.listdir("../data/xor/concepts"):
fullpath = os.path.join("../data/xor/concepts", dirname)
for dirname in os.listdir("../../../data/concepts/xor"):
fullpath = os.path.join("../../../data/concepts/xor", dirname)
if os.path.isdir(fullpath):
for i in range(4):

Expand Down Expand Up @@ -524,8 +601,8 @@ def mnmath_tcav_setup():
"xxxxxxx9",
]
tmp_concept_dict = {}
for dirname in os.listdir("../data/concepts"):
fullpath = os.path.join("../data/concepts", dirname)
for dirname in os.listdir("../../../data/concepts/mnmath"):
fullpath = os.path.join("../../../data/concepts/mnmath", dirname)
if os.path.isdir(fullpath):
for i in range(8):
concept_name = ''
Expand Down Expand Up @@ -566,7 +643,9 @@ def mnmath_tcav_setup():
# get everything
args, dataset, model = setup()

seeds = [123, 456, 789, 1011, 1213]
# CHANGE HERE
seeds = [123, 456, 789, 1011, 1213, 1415, 1617, 1819, 2021, 2223]
#seeds = [1415, 1617, 1819, 2021, 2223]
model_path = f"best_model_{args.dataset}_{args.model}"
sddoia_full = ""
to_add = "" # "_padd_random"
Expand All @@ -577,7 +656,9 @@ def mnmath_tcav_setup():

print("Doing seed", seed)

current_model_path = f"{model_path}_{seed}.pth"
# CHANGE HERE
current_model_path = f"../../../jobs/{args.dataset}_nn_run_nocsup/{model_path}_{seed}.pth"
print(f"Loading model from {current_model_path}")

if not os.path.exists(current_model_path):
print(f"{current_model_path} is missing...")
Expand Down Expand Up @@ -608,6 +689,14 @@ def mnmath_tcav_setup():
to_add = "_full"
validloader, class_dict, concept_dict = sddoia_tcav_setup()

# added by daniel: reset validloader with batch_size=1 and shuffle=False
validloader = torch.utils.data.DataLoader(
validloader.dataset,
batch_size=1,
shuffle=False, # Disable shuffling for consistent ordering
num_workers=4,
)

validate(
model,
args.dataset,
Expand Down