diff --git a/examples/deepscaler/train_deepscaler_nb.py b/examples/deepscaler/train_deepscaler_nb.py
index f15ade579..ae7ae610a 100644
--- a/examples/deepscaler/train_deepscaler_nb.py
+++ b/examples/deepscaler/train_deepscaler_nb.py
@@ -17,6 +17,7 @@
 import optax
 from orbax import checkpoint as ocp
 import qwix
+import wandb
 
 # ====== Logging Configuration ======
 # 1. Force absl to use python logging
@@ -286,6 +287,32 @@
 )
 
 
+# Start a Weights & Biases run and record this script's hyperparameters as the
+# run config. Keys are lower snake_case so they group consistently in the UI.
+wandb.init(
+    project="deepscaler",
+    reinit=True,
+    config={
+        "batch_size": BATCH_SIZE,
+        "mini_batch_size": MINI_BATCH_SIZE,
+        "learning_rate": LEARNING_RATE,
+        "b1": B1,
+        "b2": B2,
+        "warmup_steps": WARMUP_STEPS,
+        "weight_decay": WEIGHT_DECAY,
+        "num_steps": MAX_STEPS,
+        "num_generations": NUM_GENERATIONS,
+        "beta": BETA,
+        "epsilon": EPSILON,
+        "epsilon_high": EPSILON_HIGH,
+        "max_response_length": MAX_RESPONSE_LENGTH,
+        "temperature": TEMPERATURE,
+        "top_p": TOP_P,
+        "top_k": TOP_K,
+    },
+)
+
+
 def create_datasets(
     train_ds_path: str = DEEPSCALER_DATA_PATH,
     test_ds_path: str = AIME_2024_DATA_PATH,