"""Minimal Weights & Biases (W&B) experiment-tracking example.

Starts a W&B run that records a set of hyperparameters, logs training
metrics on every batch, then exports the model to ONNX and records the
exported file as a versioned model artifact. ``get_model`` and ``get_data``
are not defined in this snippet and must be provided by the caller's code.
"""

# Import the W&B Python SDK.
import wandb

# 2. Hyperparameters to capture on the run. They are passed to wandb.init()
# below; assigning a plain dict to ``wandb.config`` would replace the Config
# object and break attribute access such as ``config.epochs``.
HYPERPARAMETERS = {"epochs": 100, "learning_rate": 0.001, "batch_size": 128}
MODEL_PATH = "model.onnx"

# 1. Start a W&B run. The ``with`` block finishes the run on exit, even if
# training raises, so the run is never left dangling.
with wandb.init(
    project="cat-classification",
    notes="",
    tags=["baseline", "paper1"],
    config=HYPERPARAMETERS,
) as run:
    # Set up model and data.
    model, dataloader = get_model(), get_data()

    for _epoch in range(run.config.epochs):
        for batch in dataloader:
            # NOTE(review): ``batch`` is never passed to training_step();
            # confirm whether the call should be training_step(batch).
            loss, accuracy = model.training_step()
            # 3. Log metrics inside the training loop to visualize
            # model performance.
            run.log({"accuracy": accuracy, "loss": loss})

    # Export the trained model. NOTE(review): assumes to_onnx() writes to
    # MODEL_PATH; some frameworks (e.g. PyTorch Lightning) require the file
    # path as an explicit argument — confirm against get_model()'s API.
    model.to_onnx()

    # 4. Log the exported file as a versioned model artifact.
    # log_artifact() takes an Artifact or a file path, not a model object,
    # so it must run after the export above.
    run.log_artifact(MODEL_PATH, type="model")