From 0339887281f3e566506c3f665c91787246a4ec17 Mon Sep 17 00:00:00 2001 From: us3r247 Date: Sun, 13 Apr 2025 09:59:34 +0530 Subject: [PATCH 1/7] added dataset support for cifar100 & caltech256. Modified train.py to support pre-training configs, updated log_metrics(), added checkpoint save for pretrain datasets --- .gitignore | 3 + caltech_data.py | 103 +++++++++++++++++++++++++++++++++ cifar_data.py | 36 ++++++++++++ train.py | 150 ++++++++++++++++++++++++++++++++---------------- 4 files changed, 241 insertions(+), 51 deletions(-) create mode 100644 .gitignore create mode 100644 caltech_data.py create mode 100644 cifar_data.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9b4545f --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/scratch.ipynb +/__pycache__ +/data \ No newline at end of file diff --git a/caltech_data.py b/caltech_data.py new file mode 100644 index 0000000..46b91a1 --- /dev/null +++ b/caltech_data.py @@ -0,0 +1,103 @@ +import torch +from torchvision.transforms import v2 +from torch.utils.data import random_split,DataLoader +import os +import urllib.request +import tarfile +from typing import Optional, Callable, Any +from tqdm import tqdm +from torchvision.datasets import ImageFolder +from torchvision.datasets.folder import default_loader + +# PATH_TO_CALTECH256 = "path/to/download/caltech256" TODO: change this while trainign... + +PATH_TO_CALTECH256 = "/mnt/769EC2439EC1FB9D/vsc_projs/caltech256" + + + +class CustomCaltech256(ImageFolder): + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + download: bool = False, + custom_url: Optional[str] = None, + filename: str = "256_ObjectCategories.tar", + ): + self.root = os.path.expanduser(root) + self.custom_url = custom_url or ( + "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1" + ) + self.filename = filename + self.filepath = os.path.join(self.root, self.filename) + self.data_folder = os.path.join(self.root, "256_ObjectCategories") + + if download: + self._download() + + super().__init__( + root=self.data_folder, + transform=transform, + target_transform=target_transform, + loader=loader + ) + + def _download(self): + if os.path.isdir(self.data_folder): + print("✅ Caltech-256 already extracted.") + return + + os.makedirs(self.root, exist_ok=True) + + if not os.path.isfile(self.filepath): + print("⬇️ Downloading Caltech-256...") + + def progress_hook(t): + last_b = [0] + def update_to(block_num=1, block_size=1, total_size=None): + if total_size is not None: + t.total = total_size + downloaded = block_num * block_size + t.update(downloaded - last_b[0]) + last_b[0] = downloaded + return update_to + + with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=self.filename) as t: + urllib.request.urlretrieve(self.custom_url, self.filepath, reporthook=progress_hook(t)) + + print("✅ Download complete.") + + print("📦 Extracting Caltech-256...") + with tarfile.open(self.filepath, "r") as tar: + tar.extractall(path=self.root) + print("✅ Extraction complete.") + + +transforms = v2.Compose([ + v2.PILToTensor(), + v2.ToDtype(torch.float32,scale=True), + v2.Resize(256), + v2.CenterCrop(224), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), #imagenet norm values... + + # augments if to be added + v2.AutoAugment() + +]) + +caltech256 = CustomCaltech256( + root=PATH_TO_CALTECH256, + transform=transforms, + download=True, +) +train_data,val_data = random_split(caltech256,[27607,3000]) # ~90/10 + +def get_caltech_train_loader(batch_size,shuffle=True,num_workers=4): + return DataLoader(train_data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True) + +def get_caltech_val_loader(batch_size,shuffle=True,num_workers=4): + return DataLoader(val_data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True) + + diff --git a/cifar_data.py b/cifar_data.py new file mode 100644 index 0000000..e3964f2 --- /dev/null +++ b/cifar_data.py @@ -0,0 +1,36 @@ + +import torch +from torchvision import datasets +from torchvision.transforms import v2 +from torch.utils.data import random_split,DataLoader + + +# PATH_TO_CIFAR100 = "path/to/download/cifar100" TODO: change this while training + +PATH_TO_CIFAR100 = "/mnt/769EC2439EC1FB9D/vsc_projs/cifar100" + +transforms = v2.Compose([ + v2.PILToTensor(), + v2.ToDtype(torch.float32,scale=True), + v2.Resize(224), + v2.Normalize(mean=(0.5071, 0.4867, 0.4408),std=(0.2675, 0.2565, 0.2761)), #cifar100 norm values... + + # augments if to be added + v2.AutoAugment() + +]) + + +cifar100 = datasets.CIFAR100(root=PATH_TO_CIFAR100,download=True,transform=transforms) +train_data, val_data = random_split(dataset=cifar100, lengths=[45000,5000]) #90/10 + + +def get_cifar_train_loader(batch_size,shuffle=True,num_workers=4): + return DataLoader(dataset=train_data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True) + +def get_cifar_val_loader(batch_size,shuffle=True,num_workers=4): + return DataLoader(dataset=val_data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True) + + + + diff --git a/train.py b/train.py index 33c6a2b..88a646d 100644 --- a/train.py +++ b/train.py @@ -13,6 +13,8 @@ # Import custom modules from model import get_model, MODEL_CONFIG, EFFICIENT_MODEL_CONFIG from data import get_loaders +from caltech_data import get_caltech_train_loader,get_caltech_val_loader +from cifar_data import get_cifar_train_loader,get_cifar_val_loader # Training Configuration TRAIN_CONFIG = { @@ -91,6 +93,13 @@ def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): torch.save(state, filename) print(f"Checkpoint saved to {filename}") +def save_pretrain_checkpoint(model, dataset_name, output_dir='checkpoints'): + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, f'{dataset_name}_checkpoint.pth') + torch.save(model.state_dict(), path) + print(f"Saved checkpoint for {dataset_name} at: {path}") + + def load_checkpoint(model, optimizer=None, scheduler=None, filename=None): """Load checkpoint from file""" if not os.path.isfile(filename): @@ -131,33 +140,63 @@ def create_scheduler(optimizer, num_epochs, steps_per_epoch, warmup_epochs=5, mi min_lr=min_lr ) -def create_csv_logger(output_dir, model_name): - """Create a CSV logger to save training metrics""" - os.makedirs(output_dir, exist_ok=True) - csv_path = os.path.join(output_dir, f"{model_name}_training_log.csv") - - # Create CSV file and write header - with open(csv_path, 'w', newline='') as f: + +def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir="logs"): + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") + write_header = not os.path.exists(log_file) + + with open(log_file, mode='a', newline='') as f: writer = csv.writer(f) + if write_header: + writer.writerow([ + "epoch", + "train_loss", "train_acc1", "train_acc5", + "val_loss", "val_acc1", "val_acc5" + ]) writer.writerow([ - 'epoch', 'lr', - 'train_loss', 'train_acc1', 'train_acc5', - 'val_loss', 'val_acc1', 'val_acc5', - 'best_acc', 'time' + epoch + 1, + train_loss, train_acc1, train_acc5, + val_loss, val_acc1, val_acc5 ]) + + +def replace_head(model, num_classes): + model.head = nn.Linear(model.head.in_features, num_classes) + return model + + +def pretrain_on_dataset(model, train_loader, val_loader, num_classes, args,dataset_name): + model = replace_head(model, num_classes) + model.to(args.device) + + optimizer = create_optimizer(model, args.lr, args.weight_decay) + scheduler = create_scheduler( + optimizer, + num_epochs=args.epochs, + steps_per_epoch=len(train_loader), + warmup_epochs=args.warmup_epochs, + min_lr=args.min_lr + ) + criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) + + print(f"\n=== Pretraining on {dataset_name} ===") + best_acc = 0.0 + for epoch in range(args.epochs): + train_loss, train_acc1, train_acc5 = train_one_epoch( + model, train_loader, criterion, optimizer, scheduler, epoch, args.device + ) + val_loss, val_acc1, val_acc5 = validate(model, val_loader, criterion, args.device) + + best_acc = max(val_acc1, best_acc) + print(f"[{dataset_name}] Epoch {epoch+1}: Acc@1={val_acc1:.2f}% | Best={best_acc:.2f}%") + log_metrics(dataset_name, epoch, train_loss, train_acc1,train_acc5, val_loss, val_acc1,val_acc5) - return csv_path + save_pretrain_checkpoint(model, dataset_name) + + return model + -def log_metrics(csv_path, epoch, lr, train_metrics, val_metrics, best_acc, epoch_time): - """Log metrics to CSV file""" - with open(csv_path, 'a', newline='') as f: - writer = csv.writer(f) - writer.writerow([ - epoch, lr, - train_metrics[0], train_metrics[1], train_metrics[2], - val_metrics[0], val_metrics[1], val_metrics[2], - best_acc, epoch_time - ]) ##################################### # Training and Evaluation Functions @@ -269,9 +308,31 @@ def main(args): os.makedirs(output_dir, exist_ok=True) # Create model - model_name = 'swin_t_efficient' if args.efficient else 'swin_t' - print(f"Creating {'efficient' if args.efficient else 'standard'} Swin-T model") - model = get_model(model_name='swin_t', efficient=args.efficient) + model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000) + + # === Pretraining Stage: CIFAR === + if args.pretrain_cifar and not args.skip_pretrain: + print("creating dataloaders for cifar100...") + cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, cifar_train_loader, cifar_val_loader, num_classes=100, args=args, dataset_name='cifar100') + + # === Pretraining Stage: Caltech === + if args.pretrain_caltech and not args.skip_pretrain: + print("creating dataloaders for caltech256...") + caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, caltech_train_loader, caltech_val_loader, num_classes=257, args=args, dataset_name='caltech256') + + # === Final Training Stage === + print("creating dataloaders for tinyimagenet...") + train_loader, val_loader, mixup_fn = get_loaders( + batch_size=args.batch_size, + num_workers=args.workers, + img_size=MODEL_CONFIG["img_size"], + use_mixup=args.mixup + ) + model = replace_head(model, num_classes=200) model = model.to(args.device) # Print model information @@ -281,14 +342,6 @@ def main(args): # Create optimizer optimizer = create_optimizer(model, args.lr, args.weight_decay) - # Create data loaders - print("Creating data loaders") - train_loader, val_loader, mixup_fn = get_loaders( - batch_size=args.batch_size, - num_workers=args.workers, - img_size=MODEL_CONFIG["img_size"], - use_mixup=args.mixup - ) # Create scheduler scheduler = create_scheduler( @@ -307,9 +360,6 @@ def main(args): # Use label smoothing cross entropy loss criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) - # Create CSV logger - csv_path = create_csv_logger(output_dir, model_name) - # Optionally resume from checkpoint start_epoch = 0 best_acc = 0.0 @@ -332,7 +382,6 @@ def main(args): print(f"Learning rate: {args.lr}") print(f"Weight decay: {args.weight_decay}") print(f"Using mixup: {args.mixup}") - print(f"Training progress will be saved to: {csv_path}") # Training loop for epoch in range(start_epoch, args.epochs): @@ -340,8 +389,7 @@ def main(args): # Train for one epoch train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.device, mixup_fn - ) + model, train_loader, criterion, optimizer, scheduler, epoch, args.device, mixup_fn) # Evaluate on validation set val_loss, val_acc1, val_acc5 = validate(model, val_loader, criterion, args.device) @@ -355,15 +403,8 @@ def main(args): # Log metrics to CSV lr = scheduler.get_last_lr()[0] - log_metrics( - csv_path, - epoch + 1, - lr, - (train_loss, train_acc1, train_acc5), - (val_loss, val_acc1, val_acc5), - best_acc, - epoch_time - ) + log_metrics("main", epoch, train_loss, train_acc1, val_loss, val_acc1, log_dir=args.log_dir) + # Print epoch summary print(f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s") @@ -384,9 +425,8 @@ def main(args): model, optimizer, scheduler, epoch + 1, val_acc1, os.path.join(output_dir, 'model_best.pth') ) - + print(f"Training complete. Best accuracy: {best_acc:.2f}%") - print(f"Training log saved to: {csv_path}") def parse_args(): parser = argparse.ArgumentParser(description='Swin Transformer for Tiny ImageNet') @@ -394,6 +434,11 @@ def parse_args(): # Model parameters parser.add_argument('--efficient', action='store_true', help='Use efficient model variant') + # pre-training... + parser.add_argument('--pretrain_cifar', action='store_true', help='Pretrain on CIFAR first') + parser.add_argument('--pretrain_caltech', action='store_true', help='Pretrain on Caltech after CIFAR') + parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages') + # Training parameters parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], help='Batch size') parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], help='Number of epochs') @@ -414,6 +459,9 @@ def parse_args(): parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], help='Save checkpoint every N epochs') parser.add_argument('--resume', default='', help='Resume from checkpoint') + # logs + parser.add_argument('--log-dir', default='logs', type=str, help='Directory to save all training logs') + # Misc parser.add_argument('--seed', type=int, default=42, help='Random seed') parser.add_argument('--evaluate', action='store_true', help='Evaluate only') @@ -422,4 +470,4 @@ def parse_args(): if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) \ No newline at end of file From 3c3931d4abd6b0444b140c089bdf230ed842e457 Mon Sep 17 00:00:00 2001 From: Yash-Agarwal-BITS Date: Tue, 15 Apr 2025 01:16:11 +0530 Subject: [PATCH 2/7] Update train.py added graphs --- train.py | 757 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 545 insertions(+), 212 deletions(-) diff --git a/train.py b/train.py index 88a646d..618a6c2 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,10 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy import csv from tqdm import tqdm +import pandas as pd # Added for reading logs +import matplotlib.pyplot as plt # Added for plotting +import seaborn as sns # Added for confusion matrix styling +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay # Added for CM # Import custom modules from model import get_model, MODEL_CONFIG, EFFICIENT_MODEL_CONFIG @@ -19,7 +23,7 @@ # Training Configuration TRAIN_CONFIG = { "batch_size": 128, - "epochs": 100, + "epochs": 100, # Reduced for faster demonstration if needed "learning_rate": 5e-4, "min_lr": 5e-6, "weight_decay": 0.05, @@ -28,7 +32,8 @@ "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), "output_dir": "output", "log_interval": 20, - "save_interval": 10, + "save_interval": 10, # Reduced for faster demonstration + "log_dir": "logs", # Added log dir to config } ##################################### @@ -50,13 +55,19 @@ def update(self, val, n=1): self.val = val self.sum += val * n self.count += n - self.avg = self.sum / self.count + if self.count > 0: + self.avg = self.sum / self.count + else: + self.avg = 0 + def accuracy(output, target, topk=(1,)): """Compute the accuracy over the k top predictions""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) + if batch_size == 0: + return [torch.tensor(0.0) for _ in topk] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() @@ -68,20 +79,26 @@ def accuracy(output, target, topk=(1,)): res.append(correct_k.mul_(100.0 / batch_size)) return res -def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr): +def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr, base_lr): """Create a cosine learning rate scheduler with warmup""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) - + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) cosine_decay = 0.5 * (1 + np.cos(np.pi * progress)) - return max(min_lr / TRAIN_CONFIG["learning_rate"], cosine_decay) - + # Ensure the final learning rate doesn't go below min_lr + return max(min_lr / base_lr, cosine_decay) + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): """Save model checkpoint""" + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename), exist_ok=True) + print(f"Created directory: {os.path.dirname(filename)}") + state = { 'epoch': epoch, 'model': model.state_dict(), @@ -89,35 +106,60 @@ def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): 'scheduler': scheduler.state_dict() if scheduler else None, 'accuracy': accuracy, } - + torch.save(state, filename) print(f"Checkpoint saved to {filename}") def save_pretrain_checkpoint(model, dataset_name, output_dir='checkpoints'): + """Save model checkpoint after pretraining""" os.makedirs(output_dir, exist_ok=True) path = os.path.join(output_dir, f'{dataset_name}_checkpoint.pth') + # Save only the model state dict for pretraining checkpoints torch.save(model.state_dict(), path) - print(f"Saved checkpoint for {dataset_name} at: {path}") + print(f"Saved pretraining checkpoint for {dataset_name} at: {path}") -def load_checkpoint(model, optimizer=None, scheduler=None, filename=None): +def load_checkpoint(model, optimizer=None, scheduler=None, filename=None, load_optimizer_scheduler=True): """Load checkpoint from file""" - if not os.path.isfile(filename): + if not filename or not os.path.isfile(filename): print(f"No checkpoint found at {filename}") return 0, 0.0 - + print(f"Loading checkpoint from {filename}") checkpoint = torch.load(filename, map_location='cpu') - - model.load_state_dict(checkpoint['model']) - - if optimizer is not None and 'optimizer' in checkpoint: - optimizer.load_state_dict(checkpoint['optimizer']) - - if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: - scheduler.load_state_dict(checkpoint['scheduler']) - - return checkpoint['epoch'], checkpoint['accuracy'] + + # Handle both full checkpoints and state_dict-only checkpoints + if 'model' in checkpoint: + model.load_state_dict(checkpoint['model']) + else: + # Assume it's just a state_dict + model.load_state_dict(checkpoint) + # If only state_dict loaded, cannot resume optimizer/scheduler/epoch + print("Loaded model state_dict only. Cannot resume optimizer, scheduler, or epoch.") + return 0, checkpoint.get('accuracy', 0.0) # Return 0 epoch, try to get accuracy + + epoch = checkpoint.get('epoch', 0) + accuracy = checkpoint.get('accuracy', 0.0) + + if load_optimizer_scheduler: + if optimizer is not None and 'optimizer' in checkpoint: + try: + optimizer.load_state_dict(checkpoint['optimizer']) + except Exception as e: + print(f"Could not load optimizer state: {e}. Continuing without loading optimizer.") + + + if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: + try: + scheduler.load_state_dict(checkpoint['scheduler']) + except Exception as e: + print(f"Could not load scheduler state: {e}. Continuing without loading scheduler.") + + + else: + print("Skipping loading optimizer and scheduler state.") + + return epoch, accuracy def create_optimizer(model, lr, weight_decay): """Create optimizer for model""" @@ -128,20 +170,22 @@ def create_optimizer(model, lr, weight_decay): betas=(0.9, 0.999) ) -def create_scheduler(optimizer, num_epochs, steps_per_epoch, warmup_epochs=5, min_lr=5e-6): +def create_scheduler(optimizer, num_epochs, steps_per_epoch, base_lr, warmup_epochs=5, min_lr=5e-6): """Create learning rate scheduler""" num_training_steps = num_epochs * steps_per_epoch num_warmup_steps = warmup_epochs * steps_per_epoch - + return get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, - min_lr=min_lr + min_lr=min_lr, + base_lr=base_lr # Pass base_lr here ) def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir="logs"): + """Logs training and validation metrics to a CSV file.""" os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") write_header = not os.path.exists(log_file) @@ -150,98 +194,264 @@ def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_los writer = csv.writer(f) if write_header: writer.writerow([ - "epoch", - "train_loss", "train_acc1", "train_acc5", + "epoch", + "train_loss", "train_acc1", "train_acc5", "val_loss", "val_acc1", "val_acc5" ]) writer.writerow([ - epoch + 1, - train_loss, train_acc1, train_acc5, - val_loss, val_acc1, val_acc5 + epoch + 1, # Log 1-based epoch + f"{train_loss:.4f}" if train_loss is not None else "N/A", + f"{train_acc1:.2f}" if train_acc1 is not None else "N/A", + f"{train_acc5:.2f}" if train_acc5 is not None else "N/A", + f"{val_loss:.4f}" if val_loss is not None else "N/A", + f"{val_acc1:.2f}" if val_acc1 is not None else "N/A", + f"{val_acc5:.2f}" if val_acc5 is not None else "N/A" ]) def replace_head(model, num_classes): - model.head = nn.Linear(model.head.in_features, num_classes) + """Replaces the classification head of the model.""" + in_features = 0 + if hasattr(model, 'head') and hasattr(model.head, 'in_features'): + in_features = model.head.in_features + elif hasattr(model, 'fc') and hasattr(model.fc, 'in_features'): # common alternative name + in_features = model.fc.in_features + elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): # Another common name + in_features = model.classifier.in_features + elif hasattr(model, 'num_features'): # Timm models often have this property + in_features = model.num_features + else: + raise AttributeError("Cannot determine the input features of the model's classification head. Tried 'head', 'fc', 'classifier'.") + + # Replace the head + if hasattr(model, 'head'): + model.head = nn.Linear(in_features, num_classes) + elif hasattr(model, 'fc'): + model.fc = nn.Linear(in_features, num_classes) + elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): + model.classifier = nn.Linear(in_features, num_classes) + else: + # Fallback for models where head is not explicitly named 'head', 'fc' or 'classifier' + # This might need adjustment based on the specific architecture if get_model returns something unusual + print("Warning: Replacing head using a generic approach based on Timm's num_features. Ensure this is correct for the model.") + model.head = nn.Linear(in_features, num_classes) # Assume we can add a 'head' attribute + + print(f"Replaced model head with a new one for {num_classes} classes.") return model -def pretrain_on_dataset(model, train_loader, val_loader, num_classes, args,dataset_name): - model = replace_head(model, num_classes) +##################################### +# Plotting Functions +##################################### + +def plot_metrics(log_file, output_dir, dataset_name): + """Plots loss and accuracy curves from a log file.""" + if not os.path.exists(log_file): + print(f"Log file not found: {log_file}. Skipping plotting.") + return + + try: + df = pd.read_csv(log_file) + except Exception as e: + print(f"Error reading log file {log_file}: {e}. Skipping plotting.") + return + + if df.empty: + print(f"Log file {log_file} is empty. Skipping plotting.") + return + + plt.style.use('seaborn-v0_8-grid') # Use a nice style + fig, ax1 = plt.subplots(figsize=(12, 6)) + + # Plot Loss + color = 'tab:red' + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss', color=color) + if 'train_loss' in df.columns and pd.to_numeric(df['train_loss'], errors='coerce').notna().any(): + ax1.plot(df['epoch'], pd.to_numeric(df['train_loss'], errors='coerce'), label='Train Loss', color=color, linestyle='--') + if 'val_loss' in df.columns and pd.to_numeric(df['val_loss'], errors='coerce').notna().any(): + ax1.plot(df['epoch'], pd.to_numeric(df['val_loss'], errors='coerce'), label='Validation Loss', color=color) + ax1.tick_params(axis='y', labelcolor=color) + ax1.legend(loc='upper left') + + # Plot Accuracy + ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis + color = 'tab:blue' + ax2.set_ylabel('Accuracy (%)', color=color) + if 'train_acc1' in df.columns and pd.to_numeric(df['train_acc1'], errors='coerce').notna().any(): + ax2.plot(df['epoch'], pd.to_numeric(df['train_acc1'], errors='coerce'), label='Train Acc@1', color=color, linestyle='--') + if 'val_acc1' in df.columns and pd.to_numeric(df['val_acc1'], errors='coerce').notna().any(): + ax2.plot(df['epoch'], pd.to_numeric(df['val_acc1'], errors='coerce'), label='Validation Acc@1', color=color) + ax2.tick_params(axis='y', labelcolor=color) + ax2.legend(loc='lower left') + + plt.title(f'{dataset_name} - Training & Validation Metrics') + fig.tight_layout() # otherwise the right y-label is slightly clipped + + # Save plot + plot_filename = os.path.join(output_dir, f"{dataset_name}_metrics_plot.png") + plt.savefig(plot_filename) + print(f"Metrics plot saved to {plot_filename}") + plt.close(fig) # Close the figure to free memory + +def plot_confusion_matrix(all_preds, all_targets, num_classes, output_dir, dataset_name): + """Computes and plots the confusion matrix.""" + if all_preds is None or all_targets is None: + print(f"No prediction data available for {dataset_name}. Skipping confusion matrix.") + return + + cm = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes)) + + # Determine figure size based on number of classes + figsize = max(8, num_classes // 5) # Adjust divisor as needed + + fig, ax = plt.subplots(figsize=(figsize, figsize)) + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=np.arange(num_classes)) + + # Determine whether to show values based on matrix size + show_values = num_classes <= 30 # Only show values for smaller matrices + + disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format='d' if show_values else None) # Only show numbers if show_values is True + + plt.title(f'{dataset_name} - Confusion Matrix') + plt.tight_layout() + + # Save plot + cm_filename = os.path.join(output_dir, f"{dataset_name}_confusion_matrix.png") + plt.savefig(cm_filename) + print(f"Confusion matrix saved to {cm_filename}") + plt.close(fig) # Close the figure + + +##################################### +# Pretraining Function +##################################### + +def pretrain_on_dataset(model, train_loader, val_loader, num_classes, args, dataset_name, output_dir, log_dir): + """Pretrains the model on a given dataset (CIFAR or Caltech).""" + print(f"\n=== Pretraining Stage: {dataset_name} ===") + model = replace_head(model, num_classes) # Ensure head matches dataset model.to(args.device) optimizer = create_optimizer(model, args.lr, args.weight_decay) + steps_per_epoch = len(train_loader) scheduler = create_scheduler( optimizer, - num_epochs=args.epochs, - steps_per_epoch=len(train_loader), + num_epochs=args.epochs, # Use main epoch count for pretraining? Or specific pretrain epochs? Using main for now. + steps_per_epoch=steps_per_epoch, + base_lr=args.lr, warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) - criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) + # Use label smoothing for pretraining as well + criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) - print(f"\n=== Pretraining on {dataset_name} ===") best_acc = 0.0 - for epoch in range(args.epochs): + for epoch in range(args.epochs): # Use same number of epochs as main training? + print(f"\n--- {dataset_name} Epoch {epoch+1}/{args.epochs} ---") train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.device + model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args=args + ) + val_loss, val_acc1, val_acc5, _, _ = validate( # Get metrics only during training loop + model, val_loader, criterion, args.device, return_preds_targets=False ) - val_loss, val_acc1, val_acc5 = validate(model, val_loader, criterion, args.device) - best_acc = max(val_acc1, best_acc) - print(f"[{dataset_name}] Epoch {epoch+1}: Acc@1={val_acc1:.2f}% | Best={best_acc:.2f}%") - log_metrics(dataset_name, epoch, train_loss, train_acc1,train_acc5, val_loss, val_acc1,val_acc5) - - save_pretrain_checkpoint(model, dataset_name) + # Log metrics for this pretraining stage + log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) - return model + # Save best model based on validation accuracy for this stage + if val_acc1 > best_acc: + best_acc = val_acc1 + save_pretrain_checkpoint(model, dataset_name, output_dir=os.path.join(output_dir, 'checkpoints')) + print(f"[{dataset_name}] New best accuracy: {best_acc:.2f}% (Epoch {epoch+1}). Checkpoint saved.") + else: + print(f"[{dataset_name}] Epoch {epoch+1}: Acc@1={val_acc1:.2f}% | Best={best_acc:.2f}%") + print(f"--- Finished Pretraining on {dataset_name} ---") + + # --- Plotting and Final Validation for this Stage --- + log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") + plot_metrics(log_file, output_dir, dataset_name) + + # Load the best checkpoint for this stage for final validation and CM + best_checkpoint_path = os.path.join(output_dir, 'checkpoints', f'{dataset_name}_checkpoint.pth') + if os.path.exists(best_checkpoint_path): + print(f"Loading best {dataset_name} model for final validation...") + # Load only state_dict, don't need optimizer/scheduler here + load_checkpoint(model, filename=best_checkpoint_path, load_optimizer_scheduler=False) + else: + print(f"Warning: Best checkpoint {best_checkpoint_path} not found. Using model from last epoch for validation.") + + + print(f"Running final validation on {dataset_name} to generate Confusion Matrix...") + final_val_loss, final_val_acc1, final_val_acc5, all_preds, all_targets = validate( + model, val_loader, criterion, args.device, return_preds_targets=True + ) + print(f"Final {dataset_name} Validation: Loss={final_val_loss:.4f}, Acc@1={final_val_acc1:.2f}%, Acc@5={final_val_acc5:.2f}%") + + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), num_classes, output_dir, dataset_name) + else: + print(f"Could not generate confusion matrix for {dataset_name} due to missing prediction data.") + + + # Important: Return the model (potentially loaded with best weights) + return model + ##################################### # Training and Evaluation Functions ##################################### -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, device, mixup_fn=None): +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, device, args, mixup_fn=None): """Train model for one epoch""" model.train() - + losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() - - # Initialize tqdm progress bar - pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]", - leave=True, ncols=100, unit="batch") - - for images, target in pbar: + + steps_per_epoch = len(train_loader) + pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]", + leave=False, ncols=100, unit="batch") # Changed leave to False for cleaner nested loops + + for batch_idx, (images, target) in enumerate(pbar): # Move data to device - images = images.to(device) - target = target.to(device) - + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + # Apply mixup or cutmix if available if mixup_fn is not None: images, target = mixup_fn(images, target) - + # Forward pass output = model(images) - loss = criterion(output, target) - - # Measure accuracy and record loss - if mixup_fn is None: # Only measure accuracy if not using mixup/cutmix - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - top1.update(acc1[0].item(), images.size(0)) - top5.update(acc5[0].item(), images.size(0)) - + + # Handle cases where mixup changes target format + if mixup_fn is not None and len(target.shape) > 1: + loss = criterion(output, target) # SoftTargetCrossEntropy handles smoothed labels + # Accuracy calculation is ambiguous with mixup, often skipped or calculated differently + acc1, acc5 = [torch.tensor(0.0), torch.tensor(0.0)] # Placeholder + else: + loss = criterion(output, target) # Standard CE or LabelSmoothing + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + # Update meters losses.update(loss.item(), images.size(0)) - + if mixup_fn is None: # Only update accuracy if not using mixup + top1.update(acc1[0].item(), images.size(0)) + top5.update(acc5[0].item(), images.size(0)) + # Backward pass and optimize optimizer.zero_grad() loss.backward() optimizer.step() + + # Adjust learning rate based on step, not epoch (important for cosine schedule with warmup) scheduler.step() - + # Update progress bar lr = scheduler.get_last_lr()[0] pbar.set_postfix({ @@ -249,47 +459,59 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, 'Acc@1': f"{top1.avg:.2f}%" if mixup_fn is None else "N/A", 'LR': f"{lr:.6f}" }) - + + # Return average metrics for the epoch return losses.avg, top1.avg, top5.avg -def validate(model, val_loader, criterion, device): +def validate(model, val_loader, criterion, device, return_preds_targets=False): """Evaluate model on validation set""" model.eval() - + losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() - - # Initialize tqdm progress bar - pbar = tqdm(val_loader, desc="Validation", leave=True, ncols=100, unit="batch") - + + all_preds_list = [] + all_targets_list = [] + + pbar = tqdm(val_loader, desc="Validation", leave=False, ncols=100, unit="batch") # Changed leave to False + with torch.no_grad(): for images, target in pbar: - # Move data to device - images = images.to(device) - target = target.to(device) - - # Forward pass + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(images) loss = criterion(output, target) - - # Measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) - - # Update meters + losses.update(loss.item(), images.size(0)) top1.update(acc1[0].item(), images.size(0)) top5.update(acc5[0].item(), images.size(0)) - - # Update progress bar + + if return_preds_targets: + preds = torch.argmax(output, dim=1) + all_preds_list.append(preds.cpu()) # Move to CPU immediately + all_targets_list.append(target.cpu()) # Move to CPU immediately + + pbar.set_postfix({ 'Loss': f"{losses.avg:.4f}", 'Acc@1': f"{top1.avg:.2f}%", 'Acc@5': f"{top5.avg:.2f}%" }) - - print(f"* Validation: Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}%") - return losses.avg, top1.avg, top5.avg + + all_preds = None + all_targets = None + if return_preds_targets and len(all_preds_list) > 0: + all_preds = torch.cat(all_preds_list) + all_targets = torch.cat(all_targets_list) + + + # No need to print here if called during training loop, will be printed in main loop + # If called standalone (e.g., for final CM), the calling function should print + return losses.avg, top1.avg, top5.avg, all_preds, all_targets ##################################### # Main Training Loop @@ -301,173 +523,284 @@ def main(args): np.random.seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) - torch.backends.cudnn.benchmark = True - - # Create output directory - output_dir = os.path.join(args.output_dir, args.tag if args.tag else '') - os.makedirs(output_dir, exist_ok=True) - - # Create model - model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000) - - # === Pretraining Stage: CIFAR === - if args.pretrain_cifar and not args.skip_pretrain: - print("creating dataloaders for cifar100...") - cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset(model, cifar_train_loader, cifar_val_loader, num_classes=100, args=args, dataset_name='cifar100') - - # === Pretraining Stage: Caltech === - if args.pretrain_caltech and not args.skip_pretrain: - print("creating dataloaders for caltech256...") - caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset(model, caltech_train_loader, caltech_val_loader, num_classes=257, args=args, dataset_name='caltech256') - - # === Final Training Stage === - print("creating dataloaders for tinyimagenet...") + torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for speed + + # --- Setup Directories --- + base_output_dir = args.output_dir + experiment_output_dir = os.path.join(base_output_dir, args.tag) if args.tag else base_output_dir + log_dir = args.log_dir # Use dedicated log dir + os.makedirs(experiment_output_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + checkpoint_dir = os.path.join(experiment_output_dir, 'checkpoints') # Subdir for checkpoints + os.makedirs(checkpoint_dir, exist_ok=True) + + print(f"Output Directory: {experiment_output_dir}") + print(f"Log Directory: {log_dir}") + print(f"Checkpoint Directory: {checkpoint_dir}") + + + # --- Create Model --- + # Start with ImageNet classes, head will be replaced as needed + print("Initializing model...") + model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000, pretrained=True) # Load ImageNet pretrained weights + print(f"Model: swin_t (Efficient: {args.efficient}), Pretrained: True") + + + # --- Pretraining Stages --- + if not args.skip_pretrain: + if args.pretrain_cifar: + print("\n>>> Starting CIFAR-100 Pretraining Stage <<<") + cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, cifar_train_loader, cifar_val_loader, + num_classes=100, args=args, dataset_name='cifar100', + output_dir=experiment_output_dir, log_dir=log_dir) + print("\n>>> Finished CIFAR-100 Pretraining Stage <<<") + + + if args.pretrain_caltech: + print("\n>>> Starting Caltech-256 Pretraining Stage <<<") + # If CIFAR pretraining happened, the model already has a head for 100 classes. + # If not, it has the original 1000 class head. `pretrain_on_dataset` handles replacement. + caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, caltech_train_loader, caltech_val_loader, + num_classes=257, args=args, dataset_name='caltech256', + output_dir=experiment_output_dir, log_dir=log_dir) + print("\n>>> Finished Caltech-256 Pretraining Stage <<<") + else: + print("Skipping all pretraining stages as requested.") + + + # --- Final Training Stage: Tiny ImageNet --- + print("\n>>> Starting Final Training Stage: Tiny ImageNet <<<") + print("Creating dataloaders for Tiny ImageNet...") train_loader, val_loader, mixup_fn = get_loaders( batch_size=args.batch_size, num_workers=args.workers, img_size=MODEL_CONFIG["img_size"], use_mixup=args.mixup ) + + print("Replacing model head for Tiny ImageNet (200 classes)...") model = replace_head(model, num_classes=200) model = model.to(args.device) - + # Print model information num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Number of parameters: {num_params:,}") - - # Create optimizer + print(f"Number of trainable parameters (final stage): {num_params:,}") + + # Create optimizer and scheduler for the final stage optimizer = create_optimizer(model, args.lr, args.weight_decay) - - - # Create scheduler + steps_per_epoch_main = len(train_loader) scheduler = create_scheduler( optimizer, num_epochs=args.epochs, - steps_per_epoch=len(train_loader), + steps_per_epoch=steps_per_epoch_main, + base_lr=args.lr, warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) - - # Create loss function + + # Create loss function for the final stage if args.mixup: - # Use soft target cross entropy loss for mixup/cutmix - criterion = SoftTargetCrossEntropy() + criterion = SoftTargetCrossEntropy().to(args.device) + print("Using Mixup/Cutmix augmentation with SoftTargetCrossEntropy loss.") else: - # Use label smoothing cross entropy loss - criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) - - # Optionally resume from checkpoint + criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) + print(f"Using Label Smoothing Cross Entropy loss (smoothing={args.label_smoothing}).") + + # Optionally resume from a final stage checkpoint start_epoch = 0 best_acc = 0.0 if args.resume: - if os.path.isfile(args.resume): - start_epoch, best_acc = load_checkpoint(model, optimizer, scheduler, args.resume) - print(f"Resumed from epoch {start_epoch} with accuracy {best_acc:.2f}%") + resume_path = args.resume if os.path.isabs(args.resume) else os.path.join(checkpoint_dir, args.resume) + if os.path.isfile(resume_path): + print(f"Attempting to resume final stage training from: {resume_path}") + # Load optimizer and scheduler state when resuming main training + start_epoch, best_acc = load_checkpoint(model, optimizer, scheduler, resume_path, load_optimizer_scheduler=True) + print(f"Resumed final stage from epoch {start_epoch}. Previous best accuracy: {best_acc:.2f}%") + start_epoch = start_epoch # Checkpoint saves epoch+1, so start from the returned value else: - print(f"No checkpoint found at {args.resume}") - - # Evaluation only + print(f"Resume checkpoint not found at '{resume_path}'. Starting final training from scratch.") + + # Evaluation only mode for the final model if args.evaluate: - print("Running evaluation") - validate(model, val_loader, criterion, args.device) - return - - # Print training configuration - print(f"Starting training for {args.epochs} epochs") - print(f"Batch size: {args.batch_size}") - print(f"Learning rate: {args.lr}") - print(f"Weight decay: {args.weight_decay}") - print(f"Using mixup: {args.mixup}") - - # Training loop + print("--- Running Evaluation Only Mode ---") + eval_checkpoint_path = args.resume if args.resume else os.path.join(experiment_output_dir, 'model_best.pth') + if os.path.isfile(eval_checkpoint_path): + print(f"Loading model from: {eval_checkpoint_path} for evaluation...") + # Don't load optimizer/scheduler for evaluation + load_checkpoint(model, filename=eval_checkpoint_path, load_optimizer_scheduler=False) + print("Running validation...") + val_loss, val_acc1, val_acc5, all_preds, all_targets = validate( + model, val_loader, criterion, args.device, return_preds_targets=True + ) + print(f"\nEvaluation Results (Tiny ImageNet):") + print(f" Loss: {val_loss:.4f}") + print(f" Acc@1: {val_acc1:.2f}%") + print(f" Acc@5: {val_acc5:.2f}%") + + # Plot confusion matrix for evaluation + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_eval") + else: + print("Could not generate confusion matrix due to missing prediction data.") + + else: + print(f"Evaluation checkpoint '{eval_checkpoint_path}' not found. Cannot evaluate.") + return # Exit after evaluation + + + # --- Main Training Loop --- + print(f"\n--- Starting Final Training Loop (Tiny ImageNet) for {args.epochs - start_epoch} epochs ---") + print(f"Batch size: {args.batch_size}, Initial LR: {args.lr}, Weight Decay: {args.weight_decay}, Mixup: {args.mixup}") + for epoch in range(start_epoch, args.epochs): epoch_start = time.time() - - # Train for one epoch + print(f"\n--- Tiny ImageNet Epoch {epoch+1}/{args.epochs} ---") + + # Train train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.device, mixup_fn) - - # Evaluate on validation set - val_loss, val_acc1, val_acc5 = validate(model, val_loader, criterion, args.device) - - # Calculate epoch time + model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args, mixup_fn + ) + + # Validate + val_loss, val_acc1, val_acc5, _, _ = validate( # Don't need preds/targets here + model, val_loader, criterion, args.device, return_preds_targets=False + ) + epoch_time = time.time() - epoch_start - - # Save checkpoint + + # Check if current epoch is best is_best = val_acc1 > best_acc - best_acc = max(val_acc1, best_acc) - - # Log metrics to CSV - lr = scheduler.get_last_lr()[0] - log_metrics("main", epoch, train_loss, train_acc1, val_loss, val_acc1, log_dir=args.log_dir) + if is_best: + old_best = best_acc + best_acc = val_acc1 + print(f"*** New Best Accuracy: {best_acc:.2f}% (Improved from {old_best:.2f}%) ***") + else: + print(f"Validation Acc@1: {val_acc1:.2f}% (Best: {best_acc:.2f}%)") + + + # Log metrics to CSV for the main training stage + log_metrics("main", epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) - - # Print epoch summary - print(f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s") - print(f" Train: Loss {train_loss:.4f}, Acc@1 {train_acc1:.2f}%, Acc@5 {train_acc5:.2f}%") - print(f" Valid: Loss {val_loss:.4f}, Acc@1 {val_acc1:.2f}%, Acc@5 {val_acc5:.2f}%") - print(f" Best accuracy: {best_acc:.2f}%") - # Save checkpoint periodically if (epoch + 1) % args.save_interval == 0: save_checkpoint( model, optimizer, scheduler, epoch + 1, val_acc1, - os.path.join(output_dir, f'checkpoint_epoch{epoch+1}.pth') + os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch+1}.pth') ) - + # Always save the best model if is_best: save_checkpoint( - model, optimizer, scheduler, epoch + 1, val_acc1, - os.path.join(output_dir, 'model_best.pth') + model, optimizer, scheduler, epoch + 1, best_acc, # Save best_acc here + os.path.join(experiment_output_dir, 'model_best.pth') # Save best model in parent dir ) - print(f"Training complete. Best accuracy: {best_acc:.2f}%") + # Print epoch summary + print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s | LR: {scheduler.get_last_lr()[0]:.6f}") + print(f" Train -> Loss: {train_loss:.4f}, Acc@1: {train_acc1:.2f}%" if not args.mixup else f" Train -> Loss: {train_loss:.4f}, Acc@1: N/A (Mixup)") + print(f" Valid -> Loss: {val_loss:.4f}, Acc@1: {val_acc1:.2f}%, Acc@5: {val_acc5:.2f}%") + + + print(f"\n--- Finished Final Training Stage (Tiny ImageNet) ---") + print(f"Best validation accuracy achieved: {best_acc:.2f}%") + + # --- Final Plotting and Confusion Matrix for Main Training --- + # Plot loss/accuracy curves for the main training + main_log_file = os.path.join(log_dir, "main_log.csv") + plot_metrics(main_log_file, experiment_output_dir, "main") + + # Load the *best* model for the final confusion matrix + best_model_path = os.path.join(experiment_output_dir, 'model_best.pth') + if os.path.exists(best_model_path): + print(f"Loading best model from {best_model_path} for final confusion matrix...") + # Create a fresh instance or reload into the current one + # Re-create model to ensure clean state if needed, though loading state_dict should be fine + final_model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=200) # Head already replaced + final_model = replace_head(final_model, num_classes=200) + load_checkpoint(final_model, filename=best_model_path, load_optimizer_scheduler=False) + final_model.to(args.device) + final_model.eval() + + print("Running validation on best model for final confusion matrix...") + _, _, _, all_preds, all_targets = validate( + final_model, val_loader, criterion, args.device, return_preds_targets=True + ) + + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_best_model") + else: + print("Could not generate final confusion matrix due to missing prediction data.") + + else: + print(f"Best model checkpoint '{best_model_path}' not found. Cannot generate confusion matrix for the best model.") + + print("\n>>> All Training Stages Complete <<<") + def parse_args(): - parser = argparse.ArgumentParser(description='Swin Transformer for Tiny ImageNet') - - # Model parameters - parser.add_argument('--efficient', action='store_true', help='Use efficient model variant') - - # pre-training... - parser.add_argument('--pretrain_cifar', action='store_true', help='Pretrain on CIFAR first') - parser.add_argument('--pretrain_caltech', action='store_true', help='Pretrain on Caltech after CIFAR') - parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages') - - # Training parameters - parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], help='Batch size') - parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], help='Number of epochs') - parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], help='Learning rate') - parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], help='Minimum learning rate') - parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], help='Warmup epochs') - parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], help='Weight decay') - parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], help='Label smoothing factor') - parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use') - parser.add_argument('--mixup', action='store_true', help='Use mixup and cutmix augmentation') - - # Data loading + parser = argparse.ArgumentParser(description='Swin Transformer Training with Pretraining Options') + + # --- Model --- + parser.add_argument('--model-name', type=str, default='swin_t', help='Name of the model architecture (e.g., swin_t)') + parser.add_argument('--efficient', action='store_true', help='Use efficient model variant (if available)') + parser.add_argument('--no-pretrained', action='store_true', help='Do not use ImageNet pretrained weights initially') + + # --- Pre-training --- + parser.add_argument('--pretrain-cifar', action='store_true', help='Pretrain on CIFAR-100 first') + parser.add_argument('--pretrain-caltech', action='store_true', help='Pretrain on Caltech-256 (after CIFAR if specified, otherwise from ImageNet)') + parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages and train directly on Tiny ImageNet') + # parser.add_argument('--pretrain-epochs', type=int, default=30, help='Number of epochs for each pretraining stage (if different from main epochs)') # Optional: Separate epoch control + + # --- Main Training --- + parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], help='Input batch size for training') + parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], help='Number of epochs to train') + parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], help='Initial learning rate') + parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], help='Minimum learning rate for scheduler') + parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], help='Number of warmup epochs') + parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], help='Optimizer weight decay') + parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], help='Label smoothing factor (if not using mixup)') + parser.add_argument('--mixup', action='store_true', help='Use mixup and cutmix augmentation (disables label smoothing)') + + # --- Data & Device --- + parser.add_argument('--img-size', type=int, default=MODEL_CONFIG['img_size'], help='Input image size') # Make img_size configurable if needed parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers') - - # Checkpointing - parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Path to save output') - parser.add_argument('--tag', default='', help='Tag for the experiment') - parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], help='Save checkpoint every N epochs') - parser.add_argument('--resume', default='', help='Resume from checkpoint') - - # logs - parser.add_argument('--log-dir', default='logs', type=str, help='Directory to save all training logs') - - # Misc - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--evaluate', action='store_true', help='Evaluate only') - - return parser.parse_args() + parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use (e.g., "cuda", "cpu")') + + # --- Checkpointing & Logging --- + parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Base directory to save checkpoints and logs') + parser.add_argument('--log-dir', default=TRAIN_CONFIG['log_dir'], help='Directory within output-dir to save CSV logs and plots') + parser.add_argument('--tag', default='', type=str, help='Optional tag for experiment directory name') + parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], help='Save checkpoint every N epochs during main training') + parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Path to latest checkpoint to resume main training (or for evaluation)') + + # --- Misc --- + parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility') + parser.add_argument('--evaluate', action='store_true', help='Perform evaluation only on the validation set (requires --resume or finds model_best.pth)') + + args = parser.parse_args() + + # Set device based on argument + args.device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + # If mixup is used, disable label smoothing effect by setting it to 0 + if args.mixup: + args.label_smoothing = 0.0 + print("Mixup enabled, label smoothing set to 0.") + + # Ensure log_dir is inside output_dir unless absolute path is given + if not os.path.isabs(args.log_dir): + args.log_dir = os.path.join(args.output_dir, args.tag if args.tag else '', args.log_dir) + + + return args + if __name__ == '__main__': args = parse_args() + # Ensure necessary directories exist based on final paths + os.makedirs(args.log_dir, exist_ok=True) main(args) \ No newline at end of file From 502a874d47f14653bc0ea42f12b785288a9785b1 Mon Sep 17 00:00:00 2001 From: us3r247 <81853860+us3r247@users.noreply.github.com> Date: Wed, 16 Apr 2025 02:33:57 +0530 Subject: [PATCH 3/7] fixed get_model() calls --- train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 618a6c2..4abd4be 100644 --- a/train.py +++ b/train.py @@ -540,10 +540,9 @@ def main(args): # --- Create Model --- - # Start with ImageNet classes, head will be replaced as needed print("Initializing model...") - model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000, pretrained=True) # Load ImageNet pretrained weights - print(f"Model: swin_t (Efficient: {args.efficient}), Pretrained: True") + model = get_model(model_name='swin_t', efficient=args.efficient) + print(f"Model: swin_t (Efficient: {args.efficient})") # --- Pretraining Stages --- @@ -719,8 +718,9 @@ def main(args): print(f"Loading best model from {best_model_path} for final confusion matrix...") # Create a fresh instance or reload into the current one # Re-create model to ensure clean state if needed, though loading state_dict should be fine - final_model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=200) # Head already replaced - final_model = replace_head(final_model, num_classes=200) + final_model = get_model(model_name='swin_t', efficient=args.efficient) # Head already replaced + # final_model = replace_head(final_model, num_classes=200) + load_checkpoint(final_model, filename=best_model_path, load_optimizer_scheduler=False) final_model.to(args.device) final_model.eval() @@ -747,7 +747,7 @@ def parse_args(): # --- Model --- parser.add_argument('--model-name', type=str, default='swin_t', help='Name of the model architecture (e.g., swin_t)') parser.add_argument('--efficient', action='store_true', help='Use efficient model variant (if available)') - parser.add_argument('--no-pretrained', action='store_true', help='Do not use ImageNet pretrained weights initially') + # parser.add_argument('--no-pretrained', action='store_true', help='Do not use ImageNet pretrained weights initially') # --- Pre-training --- parser.add_argument('--pretrain-cifar', action='store_true', help='Pretrain on CIFAR-100 first') @@ -803,4 +803,4 @@ def parse_args(): args = parse_args() # Ensure necessary directories exist based on final paths os.makedirs(args.log_dir, exist_ok=True) - main(args) \ No newline at end of file + main(args) From 8a56c1c809ed08440a1f14089ab7c24f07a21757 Mon Sep 17 00:00:00 2001 From: Yash Agarwal <152529238+Yash-Agarwal-BITS@users.noreply.github.com> Date: Wed, 16 Apr 2025 02:44:40 +0530 Subject: [PATCH 4/7] Update train.py --- train.py | 862 ++++++++++++++++++++++++++----------------------------- 1 file changed, 415 insertions(+), 447 deletions(-) diff --git a/train.py b/train.py index 4abd4be..f8089ee 100644 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ import time import argparse import numpy as np +import pandas as pd # Added for reading logs import torch import torch.nn as nn import torch.nn.functional as F @@ -9,9 +10,7 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy import csv from tqdm import tqdm -import pandas as pd # Added for reading logs import matplotlib.pyplot as plt # Added for plotting -import seaborn as sns # Added for confusion matrix styling from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay # Added for CM # Import custom modules @@ -23,7 +22,7 @@ # Training Configuration TRAIN_CONFIG = { "batch_size": 128, - "epochs": 100, # Reduced for faster demonstration if needed + "epochs": 100, "learning_rate": 5e-4, "min_lr": 5e-6, "weight_decay": 0.05, @@ -32,8 +31,7 @@ "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), "output_dir": "output", "log_interval": 20, - "save_interval": 10, # Reduced for faster demonstration - "log_dir": "logs", # Added log dir to config + "save_interval": 10, } ##################################### @@ -55,10 +53,8 @@ def update(self, val, n=1): self.val = val self.sum += val * n self.count += n - if self.count > 0: - self.avg = self.sum / self.count - else: - self.avg = 0 + # Prevent division by zero if count is 0 + self.avg = self.sum / self.count if self.count > 0 else 0 def accuracy(output, target, topk=(1,)): @@ -66,8 +62,6 @@ def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) - if batch_size == 0: - return [torch.tensor(0.0) for _ in topk] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() @@ -87,18 +81,13 @@ def lr_lambda(current_step): progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) cosine_decay = 0.5 * (1 + np.cos(np.pi * progress)) - # Ensure the final learning rate doesn't go below min_lr + # Ensure the final learning rate is at least min_lr return max(min_lr / base_lr, cosine_decay) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): """Save model checkpoint""" - if not os.path.exists(os.path.dirname(filename)): - os.makedirs(os.path.dirname(filename), exist_ok=True) - print(f"Created directory: {os.path.dirname(filename)}") - state = { 'epoch': epoch, 'model': model.state_dict(), @@ -111,55 +100,51 @@ def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): print(f"Checkpoint saved to {filename}") def save_pretrain_checkpoint(model, dataset_name, output_dir='checkpoints'): - """Save model checkpoint after pretraining""" os.makedirs(output_dir, exist_ok=True) path = os.path.join(output_dir, f'{dataset_name}_checkpoint.pth') - # Save only the model state dict for pretraining checkpoints torch.save(model.state_dict(), path) print(f"Saved pretraining checkpoint for {dataset_name} at: {path}") -def load_checkpoint(model, optimizer=None, scheduler=None, filename=None, load_optimizer_scheduler=True): +def load_checkpoint(model, optimizer=None, scheduler=None, filename=None): """Load checkpoint from file""" - if not filename or not os.path.isfile(filename): + if not os.path.isfile(filename): print(f"No checkpoint found at {filename}") return 0, 0.0 print(f"Loading checkpoint from {filename}") checkpoint = torch.load(filename, map_location='cpu') - # Handle both full checkpoints and state_dict-only checkpoints - if 'model' in checkpoint: - model.load_state_dict(checkpoint['model']) - else: - # Assume it's just a state_dict - model.load_state_dict(checkpoint) - # If only state_dict loaded, cannot resume optimizer/scheduler/epoch - print("Loaded model state_dict only. Cannot resume optimizer, scheduler, or epoch.") - return 0, checkpoint.get('accuracy', 0.0) # Return 0 epoch, try to get accuracy + # Adjust for potential DataParallel prefix 'module.' + state_dict = checkpoint['model'] + new_state_dict = {} + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k # remove `module.` prefix + new_state_dict[name] = v + model.load_state_dict(new_state_dict) - epoch = checkpoint.get('epoch', 0) - accuracy = checkpoint.get('accuracy', 0.0) - if load_optimizer_scheduler: - if optimizer is not None and 'optimizer' in checkpoint: - try: - optimizer.load_state_dict(checkpoint['optimizer']) - except Exception as e: - print(f"Could not load optimizer state: {e}. Continuing without loading optimizer.") + if optimizer is not None and 'optimizer' in checkpoint: + try: + optimizer.load_state_dict(checkpoint['optimizer']) + except ValueError as e: + print(f"Could not load optimizer state: {e}. This might happen if model structure changed.") - if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: - try: - scheduler.load_state_dict(checkpoint['scheduler']) - except Exception as e: - print(f"Could not load scheduler state: {e}. Continuing without loading scheduler.") + if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: + try: + scheduler.load_state_dict(checkpoint['scheduler']) + except KeyError as e: + print(f"Could not load scheduler state: {e}. This might happen if scheduler type changed.") - else: - print("Skipping loading optimizer and scheduler state.") + start_epoch = checkpoint.get('epoch', 0) + accuracy = checkpoint.get('accuracy', 0.0) + + print(f"Checkpoint loaded. Resuming from epoch {start_epoch} with validation accuracy {accuracy:.2f}%") + + return start_epoch, accuracy - return epoch, accuracy def create_optimizer(model, lr, weight_decay): """Create optimizer for model""" @@ -185,7 +170,6 @@ def create_scheduler(optimizer, num_epochs, steps_per_epoch, base_lr, warmup_epo def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir="logs"): - """Logs training and validation metrics to a CSV file.""" os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") write_header = not os.path.exists(log_file) @@ -199,212 +183,181 @@ def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_los "val_loss", "val_acc1", "val_acc5" ]) writer.writerow([ - epoch + 1, # Log 1-based epoch - f"{train_loss:.4f}" if train_loss is not None else "N/A", - f"{train_acc1:.2f}" if train_acc1 is not None else "N/A", - f"{train_acc5:.2f}" if train_acc5 is not None else "N/A", - f"{val_loss:.4f}" if val_loss is not None else "N/A", - f"{val_acc1:.2f}" if val_acc1 is not None else "N/A", - f"{val_acc5:.2f}" if val_acc5 is not None else "N/A" + epoch + 1, + train_loss if train_loss is not None else 'N/A', # Handle None cases + train_acc1 if train_acc1 is not None else 'N/A', + train_acc5 if train_acc5 is not None else 'N/A', + val_loss, val_acc1, val_acc5 ]) def replace_head(model, num_classes): - """Replaces the classification head of the model.""" - in_features = 0 - if hasattr(model, 'head') and hasattr(model.head, 'in_features'): - in_features = model.head.in_features - elif hasattr(model, 'fc') and hasattr(model.fc, 'in_features'): # common alternative name - in_features = model.fc.in_features - elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): # Another common name - in_features = model.classifier.in_features - elif hasattr(model, 'num_features'): # Timm models often have this property - in_features = model.num_features - else: - raise AttributeError("Cannot determine the input features of the model's classification head. Tried 'head', 'fc', 'classifier'.") - - # Replace the head - if hasattr(model, 'head'): - model.head = nn.Linear(in_features, num_classes) - elif hasattr(model, 'fc'): - model.fc = nn.Linear(in_features, num_classes) - elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): - model.classifier = nn.Linear(in_features, num_classes) - else: - # Fallback for models where head is not explicitly named 'head', 'fc' or 'classifier' - # This might need adjustment based on the specific architecture if get_model returns something unusual - print("Warning: Replacing head using a generic approach based on Timm's num_features. Ensure this is correct for the model.") - model.head = nn.Linear(in_features, num_classes) # Assume we can add a 'head' attribute - - print(f"Replaced model head with a new one for {num_classes} classes.") + in_features = model.head.in_features # Get input features from existing head + model.head = nn.Linear(in_features, num_classes) + print(f"Replaced model head for {num_classes} classes.") return model ##################################### -# Plotting Functions +# Plotting Utilities # ##################################### def plot_metrics(log_file, output_dir, dataset_name): - """Plots loss and accuracy curves from a log file.""" - if not os.path.exists(log_file): - print(f"Log file not found: {log_file}. Skipping plotting.") - return - + """Plots training/validation loss and accuracy from log file.""" try: df = pd.read_csv(log_file) + df.replace('N/A', np.nan, inplace=True) # Replace 'N/A' with NaN + df = df.astype({ # Convert relevant columns to numeric, errors='coerce' turns failures into NaN + 'train_loss': float, 'train_acc1': float, 'train_acc5': float, + 'val_loss': float, 'val_acc1': float, 'val_acc5': float + }, errors='coerce') + + + epochs = df['epoch'] + + plt.style.use('seaborn-v0_8-grid') # Use a nice style + fig, axes = plt.subplots(1, 2, figsize=(15, 5)) + + # Plot Loss + axes[0].plot(epochs, df['train_loss'], 'bo-', label='Train Loss') + axes[0].plot(epochs, df['val_loss'], 'ro-', label='Validation Loss') + axes[0].set_title(f'{dataset_name} - Loss vs. Epochs') + axes[0].set_xlabel('Epochs') + axes[0].set_ylabel('Loss') + axes[0].legend() + axes[0].grid(True) + + # Plot Accuracy + if df['train_acc1'].notna().any(): # Only plot train acc if data exists + axes[1].plot(epochs, df['train_acc1'], 'bo-', label='Train Accuracy@1') + axes[1].plot(epochs, df['val_acc1'], 'ro-', label='Validation Accuracy@1') + # Optionally plot Acc@5 + # if df['train_acc5'].notna().any(): + # axes[1].plot(epochs, df['train_acc5'], 'b--', label='Train Accuracy@5') + # axes[1].plot(epochs, df['val_acc5'], 'r--', label='Validation Accuracy@5') + axes[1].set_title(f'{dataset_name} - Accuracy vs. Epochs') + axes[1].set_xlabel('Epochs') + axes[1].set_ylabel('Accuracy (%)') + axes[1].legend() + axes[1].grid(True) + + plt.tight_layout() + plot_filename = os.path.join(output_dir, f'{dataset_name}_metrics_plot.png') + plt.savefig(plot_filename) + print(f"Metrics plot saved to {plot_filename}") + plt.close(fig) # Close the figure to free memory + + except FileNotFoundError: + print(f"Log file not found at {log_file}, skipping metrics plot.") except Exception as e: - print(f"Error reading log file {log_file}: {e}. Skipping plotting.") - return + print(f"Could not generate metrics plot: {e}") - if df.empty: - print(f"Log file {log_file} is empty. Skipping plotting.") - return - plt.style.use('seaborn-v0_8-grid') # Use a nice style - fig, ax1 = plt.subplots(figsize=(12, 6)) - - # Plot Loss - color = 'tab:red' - ax1.set_xlabel('Epoch') - ax1.set_ylabel('Loss', color=color) - if 'train_loss' in df.columns and pd.to_numeric(df['train_loss'], errors='coerce').notna().any(): - ax1.plot(df['epoch'], pd.to_numeric(df['train_loss'], errors='coerce'), label='Train Loss', color=color, linestyle='--') - if 'val_loss' in df.columns and pd.to_numeric(df['val_loss'], errors='coerce').notna().any(): - ax1.plot(df['epoch'], pd.to_numeric(df['val_loss'], errors='coerce'), label='Validation Loss', color=color) - ax1.tick_params(axis='y', labelcolor=color) - ax1.legend(loc='upper left') - - # Plot Accuracy - ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis - color = 'tab:blue' - ax2.set_ylabel('Accuracy (%)', color=color) - if 'train_acc1' in df.columns and pd.to_numeric(df['train_acc1'], errors='coerce').notna().any(): - ax2.plot(df['epoch'], pd.to_numeric(df['train_acc1'], errors='coerce'), label='Train Acc@1', color=color, linestyle='--') - if 'val_acc1' in df.columns and pd.to_numeric(df['val_acc1'], errors='coerce').notna().any(): - ax2.plot(df['epoch'], pd.to_numeric(df['val_acc1'], errors='coerce'), label='Validation Acc@1', color=color) - ax2.tick_params(axis='y', labelcolor=color) - ax2.legend(loc='lower left') - - plt.title(f'{dataset_name} - Training & Validation Metrics') - fig.tight_layout() # otherwise the right y-label is slightly clipped - - # Save plot - plot_filename = os.path.join(output_dir, f"{dataset_name}_metrics_plot.png") - plt.savefig(plot_filename) - print(f"Metrics plot saved to {plot_filename}") - plt.close(fig) # Close the figure to free memory - -def plot_confusion_matrix(all_preds, all_targets, num_classes, output_dir, dataset_name): - """Computes and plots the confusion matrix.""" - if all_preds is None or all_targets is None: - print(f"No prediction data available for {dataset_name}. Skipping confusion matrix.") +def plot_confusion_matrix(all_preds, all_targets, class_names, output_dir, dataset_name): + """Plots the confusion matrix.""" + if not all_preds or not all_targets: + print("No predictions or targets found, skipping confusion matrix.") return - cm = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes)) - - # Determine figure size based on number of classes - figsize = max(8, num_classes // 5) # Adjust divisor as needed - - fig, ax = plt.subplots(figsize=(figsize, figsize)) - disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=np.arange(num_classes)) - - # Determine whether to show values based on matrix size - show_values = num_classes <= 30 # Only show values for smaller matrices - - disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format='d' if show_values else None) # Only show numbers if show_values is True + try: + cm = confusion_matrix(all_targets, all_preds) + # Adjust figure size based on number of classes + figsize = max(8, len(class_names) // 6) # Heuristic for figure size + fig, ax = plt.subplots(figsize=(figsize, figsize)) - plt.title(f'{dataset_name} - Confusion Matrix') - plt.tight_layout() + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names) + disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format='d') - # Save plot - cm_filename = os.path.join(output_dir, f"{dataset_name}_confusion_matrix.png") - plt.savefig(cm_filename) - print(f"Confusion matrix saved to {cm_filename}") - plt.close(fig) # Close the figure + ax.set_title(f'{dataset_name} - Confusion Matrix') + plt.tight_layout() # Adjust layout to prevent overlap + cm_filename = os.path.join(output_dir, f'{dataset_name}_confusion_matrix.png') + plt.savefig(cm_filename) + print(f"Confusion matrix saved to {cm_filename}") + plt.close(fig) # Close the figure + except Exception as e: + print(f"Could not generate confusion matrix: {e}") ##################################### -# Pretraining Function +# Pretraining Function # ##################################### def pretrain_on_dataset(model, train_loader, val_loader, num_classes, args, dataset_name, output_dir, log_dir): - """Pretrains the model on a given dataset (CIFAR or Caltech).""" - print(f"\n=== Pretraining Stage: {dataset_name} ===") - model = replace_head(model, num_classes) # Ensure head matches dataset + model = replace_head(model, num_classes) model.to(args.device) + # Handle DataParallel if multiple GPUs are available + if torch.cuda.device_count() > 1: + print(f"Using {torch.cuda.device_count()} GPUs for pretraining on {dataset_name}") + model = nn.DataParallel(model) + optimizer = create_optimizer(model, args.lr, args.weight_decay) + # Ensure steps_per_epoch is calculated correctly steps_per_epoch = len(train_loader) + if steps_per_epoch == 0: + print(f"Warning: train_loader for {dataset_name} is empty!") + return model # Cannot train with empty loader + scheduler = create_scheduler( optimizer, - num_epochs=args.epochs, # Use main epoch count for pretraining? Or specific pretrain epochs? Using main for now. + num_epochs=args.epochs, steps_per_epoch=steps_per_epoch, - base_lr=args.lr, + base_lr=args.lr, # Pass base LR warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) - # Use label smoothing for pretraining as well criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) + print(f"\n=== Pretraining on {dataset_name} ({num_classes} classes) ===") best_acc = 0.0 - for epoch in range(args.epochs): # Use same number of epochs as main training? - print(f"\n--- {dataset_name} Epoch {epoch+1}/{args.epochs} ---") + pretrain_output_dir = os.path.join(output_dir, f"pretrain_{dataset_name}") + os.makedirs(pretrain_output_dir, exist_ok=True) + log_file_path = os.path.join(log_dir, f"{dataset_name}_log.csv") + + # --- Training Loop --- + for epoch in range(args.epochs): train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args=args + model, train_loader, criterion, optimizer, scheduler, epoch, args.epochs, args.device, dataset_name=f"Pretrain {dataset_name}", mixup_fn=None # No mixup for pretrain ) - val_loss, val_acc1, val_acc5, _, _ = validate( # Get metrics only during training loop - model, val_loader, criterion, args.device, return_preds_targets=False + val_loss, val_acc1, val_acc5, _, _ = validate( # Discard preds/targets here + model, val_loader, criterion, args.device, dataset_name=f"Pretrain {dataset_name}" ) - # Log metrics for this pretraining stage - log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) - - # Save best model based on validation accuracy for this stage - if val_acc1 > best_acc: - best_acc = val_acc1 - save_pretrain_checkpoint(model, dataset_name, output_dir=os.path.join(output_dir, 'checkpoints')) - print(f"[{dataset_name}] New best accuracy: {best_acc:.2f}% (Epoch {epoch+1}). Checkpoint saved.") - else: - print(f"[{dataset_name}] Epoch {epoch+1}: Acc@1={val_acc1:.2f}% | Best={best_acc:.2f}%") - - - print(f"--- Finished Pretraining on {dataset_name} ---") + best_acc = max(val_acc1, best_acc) + print(f"[{dataset_name}] Epoch {epoch+1}: Val Acc@1={val_acc1:.2f}% | Best Val Acc@1={best_acc:.2f}%") - # --- Plotting and Final Validation for this Stage --- - log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") - plot_metrics(log_file, output_dir, dataset_name) - - # Load the best checkpoint for this stage for final validation and CM - best_checkpoint_path = os.path.join(output_dir, 'checkpoints', f'{dataset_name}_checkpoint.pth') - if os.path.exists(best_checkpoint_path): - print(f"Loading best {dataset_name} model for final validation...") - # Load only state_dict, don't need optimizer/scheduler here - load_checkpoint(model, filename=best_checkpoint_path, load_optimizer_scheduler=False) - else: - print(f"Warning: Best checkpoint {best_checkpoint_path} not found. Using model from last epoch for validation.") + # Log metrics + log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) + # Save checkpoint (optional, can save best only) + # save_checkpoint(model.module if isinstance(model, nn.DataParallel) else model, optimizer, scheduler, epoch + 1, val_acc1, + # os.path.join(pretrain_output_dir, f'ckpt_epoch_{epoch+1}.pth')) - print(f"Running final validation on {dataset_name} to generate Confusion Matrix...") - final_val_loss, final_val_acc1, final_val_acc5, all_preds, all_targets = validate( - model, val_loader, criterion, args.device, return_preds_targets=True - ) - print(f"Final {dataset_name} Validation: Loss={final_val_loss:.4f}, Acc@1={final_val_acc1:.2f}%, Acc@5={final_val_acc5:.2f}%") + # --- Save Final/Best Pretrained Model --- + # Retrieve the actual model from DataParallel wrapper if necessary + final_model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict() + save_pretrain_checkpoint(final_model_state, dataset_name, output_dir=pretrain_output_dir) # Pass state_dict directly - if all_preds is not None and all_targets is not None: - plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), num_classes, output_dir, dataset_name) - else: - print(f"Could not generate confusion matrix for {dataset_name} due to missing prediction data.") + # --- Generate Plots after Pretraining --- + print(f"\n--- Generating plots for {dataset_name} pretraining ---") + # Plot Loss/Accuracy Curves + plot_metrics(log_file_path, pretrain_output_dir, dataset_name) + # Generate Confusion Matrix (requires one last validation run) + print(f"Running final validation on {dataset_name} for Confusion Matrix...") + _, _, _, final_preds, final_targets = validate(model, val_loader, criterion, args.device, dataset_name=f"Final {dataset_name} Val") + class_names = [str(i) for i in range(num_classes)] # Generic class names + plot_confusion_matrix(final_preds, final_targets, class_names, pretrain_output_dir, dataset_name) - # Important: Return the model (potentially loaded with best weights) - return model + # Return the base model (without DataParallel wrapper if it was used) + return model.module if isinstance(model, nn.DataParallel) else model ##################################### # Training and Evaluation Functions ##################################### -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, device, args, mixup_fn=None): +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, num_epochs, device, dataset_name="Train", mixup_fn=None): """Train model for one epoch""" model.train() @@ -412,58 +365,73 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, top1 = AverageMeter() top5 = AverageMeter() - steps_per_epoch = len(train_loader) - pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]", - leave=False, ncols=100, unit="batch") # Changed leave to False for cleaner nested loops + # Initialize tqdm progress bar + pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [{dataset_name}]", + leave=True, ncols=100, unit="batch") - for batch_idx, (images, target) in enumerate(pbar): + steps_per_epoch = len(train_loader) # Get steps per epoch + + for i, (images, target) in enumerate(pbar): # Move data to device images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) # Apply mixup or cutmix if available + is_soft_target = False if mixup_fn is not None: images, target = mixup_fn(images, target) + is_soft_target = True # Target is now soft # Forward pass output = model(images) - # Handle cases where mixup changes target format - if mixup_fn is not None and len(target.shape) > 1: - loss = criterion(output, target) # SoftTargetCrossEntropy handles smoothed labels - # Accuracy calculation is ambiguous with mixup, often skipped or calculated differently - acc1, acc5 = [torch.tensor(0.0), torch.tensor(0.0)] # Placeholder - else: - loss = criterion(output, target) # Standard CE or LabelSmoothing - acc1, acc5 = accuracy(output, target, topk=(1, 5)) + # Calculate loss based on target type + loss = criterion(output, target) + # Measure accuracy and record loss + # Accuracy calculation is only valid if not using soft targets (mixup/cutmix) + batch_top1_avg = None + batch_top5_avg = None + if not is_soft_target: + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0].item(), images.size(0)) + top5.update(acc5[0].item(), images.size(0)) + batch_top1_avg = top1.avg # Use running average for display + batch_top5_avg = top5.avg # Update meters losses.update(loss.item(), images.size(0)) - if mixup_fn is None: # Only update accuracy if not using mixup - top1.update(acc1[0].item(), images.size(0)) - top5.update(acc5[0].item(), images.size(0)) # Backward pass and optimize optimizer.zero_grad() loss.backward() + # Gradient clipping (optional but often helpful) + # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() - # Adjust learning rate based on step, not epoch (important for cosine schedule with warmup) - scheduler.step() + # Step the scheduler *after* the optimizer step + # The LR scheduler expects step increments, adjust based on total steps + current_step = epoch * steps_per_epoch + i + scheduler.step(current_step) # Pass current step for LambdaLR-like schedulers + # Update progress bar - lr = scheduler.get_last_lr()[0] - pbar.set_postfix({ + lr = optimizer.param_groups[0]['lr'] # More reliable way to get current LR + postfix_dict = { 'Loss': f"{losses.avg:.4f}", - 'Acc@1': f"{top1.avg:.2f}%" if mixup_fn is None else "N/A", 'LR': f"{lr:.6f}" - }) + } + if batch_top1_avg is not None: + postfix_dict['Acc@1'] = f"{batch_top1_avg:.2f}%" + # if batch_top5_avg is not None: + # postfix_dict['Acc@5'] = f"{batch_top5_avg:.2f}%" + pbar.set_postfix(postfix_dict) + + # Return epoch averages (handle case where accuracy wasn't measured) + return losses.avg, top1.avg if top1.count > 0 else None, top5.avg if top5.count > 0 else None - # Return average metrics for the epoch - return losses.avg, top1.avg, top5.avg -def validate(model, val_loader, criterion, device, return_preds_targets=False): +def validate(model, val_loader, criterion, device, dataset_name="Validation"): """Evaluate model on validation set""" model.eval() @@ -471,46 +439,45 @@ def validate(model, val_loader, criterion, device, return_preds_targets=False): top1 = AverageMeter() top5 = AverageMeter() - all_preds_list = [] - all_targets_list = [] + all_preds = [] + all_targets = [] - pbar = tqdm(val_loader, desc="Validation", leave=False, ncols=100, unit="batch") # Changed leave to False + # Initialize tqdm progress bar + pbar = tqdm(val_loader, desc=f"[{dataset_name}]", leave=False, ncols=100, unit="batch") with torch.no_grad(): for images, target in pbar: + # Move data to device images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) + # Forward pass output = model(images) loss = criterion(output, target) + # Measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) + # Collect predictions for Confusion Matrix + _, predicted_indices = torch.max(output.data, 1) + all_preds.extend(predicted_indices.cpu().numpy()) + all_targets.extend(target.cpu().numpy()) + + + # Update meters losses.update(loss.item(), images.size(0)) top1.update(acc1[0].item(), images.size(0)) top5.update(acc5[0].item(), images.size(0)) - if return_preds_targets: - preds = torch.argmax(output, dim=1) - all_preds_list.append(preds.cpu()) # Move to CPU immediately - all_targets_list.append(target.cpu()) # Move to CPU immediately - - + # Update progress bar pbar.set_postfix({ 'Loss': f"{losses.avg:.4f}", 'Acc@1': f"{top1.avg:.2f}%", 'Acc@5': f"{top5.avg:.2f}%" }) - all_preds = None - all_targets = None - if return_preds_targets and len(all_preds_list) > 0: - all_preds = torch.cat(all_preds_list) - all_targets = torch.cat(all_targets_list) - - - # No need to print here if called during training loop, will be printed in main loop - # If called standalone (e.g., for final CM), the calling function should print + print(f"* [{dataset_name}]: Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% Loss {losses.avg:.4f}") + # Return averages AND the collected predictions/targets return losses.avg, top1.avg, top5.avg, all_preds, all_targets ##################################### @@ -521,286 +488,287 @@ def main(args): # Set random seed for reproducibility torch.manual_seed(args.seed) np.random.seed(args.seed) + random.seed(args.seed) # Add random seed for python's random module if torch.cuda.is_available(): - torch.cuda.manual_seed(args.seed) - torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for speed - - # --- Setup Directories --- - base_output_dir = args.output_dir - experiment_output_dir = os.path.join(base_output_dir, args.tag) if args.tag else base_output_dir - log_dir = args.log_dir # Use dedicated log dir - os.makedirs(experiment_output_dir, exist_ok=True) - os.makedirs(log_dir, exist_ok=True) - checkpoint_dir = os.path.join(experiment_output_dir, 'checkpoints') # Subdir for checkpoints - os.makedirs(checkpoint_dir, exist_ok=True) + torch.cuda.manual_seed_all(args.seed) # Seed all GPUs + # These can sometimes slow down training or cause issues, enable if needed + # torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True # Usually speeds up training - print(f"Output Directory: {experiment_output_dir}") - print(f"Log Directory: {log_dir}") - print(f"Checkpoint Directory: {checkpoint_dir}") + # Create output directory + output_dir = os.path.join(args.output_dir, args.tag if args.tag else time.strftime("%Y%m%d-%H%M%S")) + os.makedirs(output_dir, exist_ok=True) + log_dir = os.path.join(args.log_dir, args.tag if args.tag else os.path.basename(output_dir)) # Log dir specific to this run + os.makedirs(log_dir, exist_ok=True) + print(f"Output directory: {output_dir}") + print(f"Log directory: {log_dir}") + print(f"Using device: {args.device}") - # --- Create Model --- + # Create model - start with ImageNet classes (1000) or a base number + # The head will be replaced before each training stage. print("Initializing model...") - model = get_model(model_name='swin_t', efficient=args.efficient) - print(f"Model: swin_t (Efficient: {args.efficient})") - - - # --- Pretraining Stages --- - if not args.skip_pretrain: - if args.pretrain_cifar: - print("\n>>> Starting CIFAR-100 Pretraining Stage <<<") - cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset(model, cifar_train_loader, cifar_val_loader, - num_classes=100, args=args, dataset_name='cifar100', - output_dir=experiment_output_dir, log_dir=log_dir) - print("\n>>> Finished CIFAR-100 Pretraining Stage <<<") - - - if args.pretrain_caltech: - print("\n>>> Starting Caltech-256 Pretraining Stage <<<") - # If CIFAR pretraining happened, the model already has a head for 100 classes. - # If not, it has the original 1000 class head. `pretrain_on_dataset` handles replacement. - caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset(model, caltech_train_loader, caltech_val_loader, - num_classes=257, args=args, dataset_name='caltech256', - output_dir=experiment_output_dir, log_dir=log_dir) - print("\n>>> Finished Caltech-256 Pretraining Stage <<<") - else: - print("Skipping all pretraining stages as requested.") - + model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000) # Start with 1000 classes + + # === Pretraining Stage: CIFAR === + if args.pretrain_cifar and not args.skip_pretrain: + print("\n--- Starting CIFAR-100 Pretraining Stage ---") + print("Creating dataloaders for cifar100...") + cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset( + model, cifar_train_loader, cifar_val_loader, + num_classes=100, args=args, dataset_name='cifar100', + output_dir=output_dir, log_dir=log_dir + ) + print("--- CIFAR-100 Pretraining Stage Complete ---") + + + # === Pretraining Stage: Caltech === + if args.pretrain_caltech and not args.skip_pretrain: + print("\n--- Starting Caltech-256 Pretraining Stage ---") + print("Creating dataloaders for caltech256...") + caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset( + model, caltech_train_loader, caltech_val_loader, + num_classes=257, args=args, dataset_name='caltech256', + output_dir=output_dir, log_dir=log_dir + ) + print("--- Caltech-256 Pretraining Stage Complete ---") - # --- Final Training Stage: Tiny ImageNet --- - print("\n>>> Starting Final Training Stage: Tiny ImageNet <<<") - print("Creating dataloaders for Tiny ImageNet...") + # === Final Training Stage (Tiny ImageNet) === + print("\n--- Starting Final Training Stage (Tiny ImageNet) ---") + final_num_classes = 200 # Tiny ImageNet has 200 classes + print(f"Creating dataloaders for Tiny ImageNet ({final_num_classes} classes)...") train_loader, val_loader, mixup_fn = get_loaders( batch_size=args.batch_size, num_workers=args.workers, - img_size=MODEL_CONFIG["img_size"], + img_size=MODEL_CONFIG["img_size"], # Make sure this matches Swin-T expectation use_mixup=args.mixup ) - print("Replacing model head for Tiny ImageNet (200 classes)...") - model = replace_head(model, num_classes=200) - model = model.to(args.device) + # Replace head for the final dataset + model = replace_head(model, num_classes=final_num_classes) + model.to(args.device) + + # Handle DataParallel if multiple GPUs are available for the final stage + if torch.cuda.device_count() > 1 and not args.evaluate: # Don't wrap if only evaluating + print(f"Using {torch.cuda.device_count()} GPUs for final training stage.") + model = nn.DataParallel(model) + # Print model information num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Number of trainable parameters (final stage): {num_params:,}") + print(f"Final Model - Number of trainable parameters: {num_params:,}") + + # Create optimizer + optimizer = create_optimizer(model.module if isinstance(model, nn.DataParallel) else model, args.lr, args.weight_decay) + + + # Create scheduler + steps_per_epoch_final = len(train_loader) + if steps_per_epoch_final == 0: + print("Error: Final train loader is empty!") + return - # Create optimizer and scheduler for the final stage - optimizer = create_optimizer(model, args.lr, args.weight_decay) - steps_per_epoch_main = len(train_loader) scheduler = create_scheduler( optimizer, num_epochs=args.epochs, - steps_per_epoch=steps_per_epoch_main, - base_lr=args.lr, + steps_per_epoch=steps_per_epoch_final, + base_lr=args.lr, # Pass base LR warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) - # Create loss function for the final stage - if args.mixup: + # Create loss function + if args.mixup and mixup_fn is not None: + print("Using Mixup/CutMix augmentation with SoftTargetCrossEntropy loss.") criterion = SoftTargetCrossEntropy().to(args.device) - print("Using Mixup/Cutmix augmentation with SoftTargetCrossEntropy loss.") else: + print(f"Using Label Smoothing Cross Entropy loss with smoothing={args.label_smoothing}.") criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) - print(f"Using Label Smoothing Cross Entropy loss (smoothing={args.label_smoothing}).") - # Optionally resume from a final stage checkpoint + # Optionally resume from checkpoint start_epoch = 0 best_acc = 0.0 if args.resume: - resume_path = args.resume if os.path.isabs(args.resume) else os.path.join(checkpoint_dir, args.resume) - if os.path.isfile(resume_path): - print(f"Attempting to resume final stage training from: {resume_path}") - # Load optimizer and scheduler state when resuming main training - start_epoch, best_acc = load_checkpoint(model, optimizer, scheduler, resume_path, load_optimizer_scheduler=True) - print(f"Resumed final stage from epoch {start_epoch}. Previous best accuracy: {best_acc:.2f}%") - start_epoch = start_epoch # Checkpoint saves epoch+1, so start from the returned value - else: - print(f"Resume checkpoint not found at '{resume_path}'. Starting final training from scratch.") - - # Evaluation only mode for the final model + # Ensure the model head matches the checkpoint's expected classes before loading + # Note: This might require knowing the number of classes the checkpoint was saved with. + # If resuming a final stage checkpoint, the head should already be correct (200). + # If resuming a pretraining checkpoint, this logic needs adjustment. + # For simplicity, assume resume is for the final stage. + model_to_load = model.module if isinstance(model, nn.DataParallel) else model + start_epoch, best_acc = load_checkpoint(model_to_load, optimizer, scheduler, args.resume) + # Sync epoch for DataParallel case? Usually handled inside load_checkpoint if needed. + + # Evaluation only if args.evaluate: - print("--- Running Evaluation Only Mode ---") - eval_checkpoint_path = args.resume if args.resume else os.path.join(experiment_output_dir, 'model_best.pth') - if os.path.isfile(eval_checkpoint_path): - print(f"Loading model from: {eval_checkpoint_path} for evaluation...") - # Don't load optimizer/scheduler for evaluation - load_checkpoint(model, filename=eval_checkpoint_path, load_optimizer_scheduler=False) - print("Running validation...") - val_loss, val_acc1, val_acc5, all_preds, all_targets = validate( - model, val_loader, criterion, args.device, return_preds_targets=True - ) - print(f"\nEvaluation Results (Tiny ImageNet):") - print(f" Loss: {val_loss:.4f}") - print(f" Acc@1: {val_acc1:.2f}%") - print(f" Acc@5: {val_acc5:.2f}%") - - # Plot confusion matrix for evaluation - if all_preds is not None and all_targets is not None: - plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_eval") - else: - print("Could not generate confusion matrix due to missing prediction data.") - - else: - print(f"Evaluation checkpoint '{eval_checkpoint_path}' not found. Cannot evaluate.") - return # Exit after evaluation + print("\n--- Running Evaluation Only ---") + if not args.resume: + print("Warning: Evaluating without loading a checkpoint (`--resume` not specified). Using initial model weights.") + _, val_acc1, _, final_preds, final_targets = validate(model, val_loader, criterion, args.device, dataset_name="Evaluation") + print(f"Evaluation Accuracy@1: {val_acc1:.3f}%") + # Generate plots for evaluation run + class_names = [str(i) for i in range(final_num_classes)] # Generic class names + eval_output_dir = os.path.join(output_dir, "evaluation") + os.makedirs(eval_output_dir, exist_ok=True) + plot_confusion_matrix(final_preds, final_targets, class_names, eval_output_dir, "tinyimagenet_eval") + # Cannot plot loss/acc curves without training history + return + # Print training configuration + print(f"\nStarting final training for {args.epochs} epochs (from epoch {start_epoch})") + print(f"Batch size: {args.batch_size}") + print(f"Initial Learning rate: {args.lr}") + print(f"Minimum Learning rate: {args.min_lr}") + print(f"Weight decay: {args.weight_decay}") + print(f"Using mixup: {args.mixup and mixup_fn is not None}") + print(f"Label smoothing: {args.label_smoothing if not (args.mixup and mixup_fn is not None) else 'N/A (using SoftTargetCE)'}") - # --- Main Training Loop --- - print(f"\n--- Starting Final Training Loop (Tiny ImageNet) for {args.epochs - start_epoch} epochs ---") - print(f"Batch size: {args.batch_size}, Initial LR: {args.lr}, Weight Decay: {args.weight_decay}, Mixup: {args.mixup}") + # --- Final Training loop --- + log_file_path_main = os.path.join(log_dir, "main_log.csv") for epoch in range(start_epoch, args.epochs): epoch_start = time.time() - print(f"\n--- Tiny ImageNet Epoch {epoch+1}/{args.epochs} ---") - # Train + # Train for one epoch train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args, mixup_fn - ) + model, train_loader, criterion, optimizer, scheduler, epoch, args.epochs, args.device, dataset_name="TinyImageNet Train", mixup_fn=mixup_fn) - # Validate - val_loss, val_acc1, val_acc5, _, _ = validate( # Don't need preds/targets here - model, val_loader, criterion, args.device, return_preds_targets=False - ) + # Evaluate on validation set + val_loss, val_acc1, val_acc5, _, _ = validate( # Discard preds/targets during epoch validation + model, val_loader, criterion, args.device, dataset_name="TinyImageNet Val") + # Calculate epoch time epoch_time = time.time() - epoch_start - # Check if current epoch is best + # Check if current model is the best is_best = val_acc1 > best_acc if is_best: - old_best = best_acc best_acc = val_acc1 - print(f"*** New Best Accuracy: {best_acc:.2f}% (Improved from {old_best:.2f}%) ***") - else: - print(f"Validation Acc@1: {val_acc1:.2f}% (Best: {best_acc:.2f}%)") + print(f"** New Best Val Acc@1: {best_acc:.3f}% **") + + # Log metrics to CSV + # Handle None for train_acc if mixup was used + log_metrics("main", epoch, + train_loss, train_acc1, train_acc5, + val_loss, val_acc1, val_acc5, + log_dir=log_dir) + + + # Print epoch summary + print(f"--- Epoch {epoch+1}/{args.epochs} Summary ---") + print(f" Time: {epoch_time:.2f}s") + train_acc1_str = f"{train_acc1:.2f}%" if train_acc1 is not None else "N/A (Mixup)" + train_acc5_str = f"{train_acc5:.2f}%" if train_acc5 is not None else "N/A (Mixup)" + print(f" Train: Loss {train_loss:.4f}, Acc@1 {train_acc1_str}, Acc@5 {train_acc5_str}") + print(f" Valid: Loss {val_loss:.4f}, Acc@1 {val_acc1:.2f}%, Acc@5 {val_acc5:.2f}%") + print(f" Best Valid Acc@1 so far: {best_acc:.2f}%") + print("-" * (len(f"--- Epoch {epoch+1}/{args.epochs} Summary ---"))) # Divider - # Log metrics to CSV for the main training stage - log_metrics("main", epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) + # Retrieve the actual model state dict, handling DataParallel + model_state_to_save = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict() + # Save checkpoint periodically if (epoch + 1) % args.save_interval == 0: - save_checkpoint( - model, optimizer, scheduler, epoch + 1, val_acc1, - os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch+1}.pth') - ) + save_checkpoint( + model_state_to_save, optimizer, scheduler, epoch + 1, val_acc1, + os.path.join(output_dir, f'checkpoint_epoch{epoch+1}.pth') + ) - # Always save the best model + # Always save the best model based on validation accuracy if is_best: - save_checkpoint( - model, optimizer, scheduler, epoch + 1, best_acc, # Save best_acc here - os.path.join(experiment_output_dir, 'model_best.pth') # Save best model in parent dir - ) - - # Print epoch summary - print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s | LR: {scheduler.get_last_lr()[0]:.6f}") - print(f" Train -> Loss: {train_loss:.4f}, Acc@1: {train_acc1:.2f}%" if not args.mixup else f" Train -> Loss: {train_loss:.4f}, Acc@1: N/A (Mixup)") - print(f" Valid -> Loss: {val_loss:.4f}, Acc@1: {val_acc1:.2f}%, Acc@5: {val_acc5:.2f}%") - + save_checkpoint( + model_state_to_save, optimizer, scheduler, epoch + 1, best_acc, # Save best_acc here + os.path.join(output_dir, 'model_best.pth') + ) - print(f"\n--- Finished Final Training Stage (Tiny ImageNet) ---") - print(f"Best validation accuracy achieved: {best_acc:.2f}%") + print(f"\nTraining complete. Best validation accuracy: {best_acc:.2f}%") - # --- Final Plotting and Confusion Matrix for Main Training --- - # Plot loss/accuracy curves for the main training - main_log_file = os.path.join(log_dir, "main_log.csv") - plot_metrics(main_log_file, experiment_output_dir, "main") + # --- Generate Final Plots after Training --- + print(f"\n--- Generating plots for final Tiny ImageNet training ---") + # Plot Loss/Accuracy Curves from the main log file + plot_metrics(log_file_path_main, output_dir, "tinyimagenet_main") - # Load the *best* model for the final confusion matrix - best_model_path = os.path.join(experiment_output_dir, 'model_best.pth') + # Load the best model checkpoint for final validation and CM + best_model_path = os.path.join(output_dir, 'model_best.pth') if os.path.exists(best_model_path): - print(f"Loading best model from {best_model_path} for final confusion matrix...") - # Create a fresh instance or reload into the current one - # Re-create model to ensure clean state if needed, though loading state_dict should be fine - final_model = get_model(model_name='swin_t', efficient=args.efficient) # Head already replaced - # final_model = replace_head(final_model, num_classes=200) - - load_checkpoint(final_model, filename=best_model_path, load_optimizer_scheduler=False) + print(f"Loading best model from {best_model_path} for final validation...") + # Re-create model instance and load state dict (necessary if DataParallel was used during training) + final_model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=final_num_classes) + load_checkpoint(final_model, filename=best_model_path) # Load only model weights final_model.to(args.device) - final_model.eval() - - print("Running validation on best model for final confusion matrix...") - _, _, _, all_preds, all_targets = validate( - final_model, val_loader, criterion, args.device, return_preds_targets=True - ) + if torch.cuda.device_count() > 1: # Apply DataParallel if needed for validation + final_model = nn.DataParallel(final_model) + else: + print("Best model checkpoint not found. Using model from last epoch for final validation.") + final_model = model # Use the model from the last training epoch - if all_preds is not None and all_targets is not None: - plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_best_model") - else: - print("Could not generate final confusion matrix due to missing prediction data.") + # Generate Confusion Matrix using the best (or last) model + print("Running final validation on Tiny ImageNet for Confusion Matrix...") + # Use the correct criterion for validation + final_val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) # Use standard CE for final validation CM - else: - print(f"Best model checkpoint '{best_model_path}' not found. Cannot generate confusion matrix for the best model.") + _, _, _, final_preds, final_targets = validate( + final_model, val_loader, final_val_criterion, args.device, dataset_name="Final Best Model Val" + ) + class_names = [str(i) for i in range(final_num_classes)] # Generic class names + plot_confusion_matrix(final_preds, final_targets, class_names, output_dir, "tinyimagenet_main_best") - print("\n>>> All Training Stages Complete <<<") + print("--- Final Training Stage Complete ---") def parse_args(): - parser = argparse.ArgumentParser(description='Swin Transformer Training with Pretraining Options') + parser = argparse.ArgumentParser(description='Swin Transformer Training with Pretraining and Plotting') - # --- Model --- - parser.add_argument('--model-name', type=str, default='swin_t', help='Name of the model architecture (e.g., swin_t)') - parser.add_argument('--efficient', action='store_true', help='Use efficient model variant (if available)') - # parser.add_argument('--no-pretrained', action='store_true', help='Do not use ImageNet pretrained weights initially') + # Model parameters + parser.add_argument('--efficient', action='store_true', help='Use efficient model variant') - # --- Pre-training --- + # pre-training... parser.add_argument('--pretrain-cifar', action='store_true', help='Pretrain on CIFAR-100 first') - parser.add_argument('--pretrain-caltech', action='store_true', help='Pretrain on Caltech-256 (after CIFAR if specified, otherwise from ImageNet)') - parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages and train directly on Tiny ImageNet') - # parser.add_argument('--pretrain-epochs', type=int, default=30, help='Number of epochs for each pretraining stage (if different from main epochs)') # Optional: Separate epoch control - - # --- Main Training --- - parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], help='Input batch size for training') - parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], help='Number of epochs to train') - parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], help='Initial learning rate') - parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], help='Minimum learning rate for scheduler') - parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], help='Number of warmup epochs') - parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], help='Optimizer weight decay') - parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], help='Label smoothing factor (if not using mixup)') - parser.add_argument('--mixup', action='store_true', help='Use mixup and cutmix augmentation (disables label smoothing)') - - # --- Data & Device --- - parser.add_argument('--img-size', type=int, default=MODEL_CONFIG['img_size'], help='Input image size') # Make img_size configurable if needed - parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers') - parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use (e.g., "cuda", "cpu")') - - # --- Checkpointing & Logging --- - parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Base directory to save checkpoints and logs') - parser.add_argument('--log-dir', default=TRAIN_CONFIG['log_dir'], help='Directory within output-dir to save CSV logs and plots') - parser.add_argument('--tag', default='', type=str, help='Optional tag for experiment directory name') - parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], help='Save checkpoint every N epochs during main training') - parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Path to latest checkpoint to resume main training (or for evaluation)') - - # --- Misc --- - parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility') - parser.add_argument('--evaluate', action='store_true', help='Perform evaluation only on the validation set (requires --resume or finds model_best.pth)') + parser.add_argument('--pretrain-caltech', action='store_true', help='Pretrain on Caltech-256 (after CIFAR if specified)') + parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages') + + # Training parameters + parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], metavar='N', help=f'Input batch size (default: {TRAIN_CONFIG["batch_size"]})') + parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], metavar='N', help=f'Number of epochs to train (default: {TRAIN_CONFIG["epochs"]})') + parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], metavar='LR', help=f'Initial learning rate (default: {TRAIN_CONFIG["learning_rate"]})') + parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], metavar='MINLR', help=f'Minimum learning rate for cosine scheduler (default: {TRAIN_CONFIG["min_lr"]})') + parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], metavar='N', help=f'Number of warmup epochs (default: {TRAIN_CONFIG["warmup_epochs"]})') + parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], metavar='WD', help=f'Weight decay (default: {TRAIN_CONFIG["weight_decay"]})') + parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], metavar='LS', help=f'Label smoothing factor (default: {TRAIN_CONFIG["label_smoothing"]})') + parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use (cuda or cpu)') + parser.add_argument('--mixup', action='store_true', default=False, help='Use mixup and cutmix augmentation (requires timm mixup implementation in get_loaders)') + + # Data loading + parser.add_argument('--workers', type=int, default=4, metavar='N', help='Number of data loading workers (default: 4)') + + # Checkpointing + parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Path to save output (checkpoints, plots)') + parser.add_argument('--tag', default='', help='Experiment tag to append to output/log directories') + parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], metavar='N', help='Save checkpoint every N epochs (default: 10)') + parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume final training stage from checkpoint path') + + # logs + parser.add_argument('--log-dir', default='logs', type=str, help='Directory to save all training logs (.csv files)') + + # Misc + parser.add_argument('--seed', type=int, default=42, metavar='S', help='Random seed (default: 42)') + parser.add_argument('--evaluate', action='store_true', help='Perform evaluation only (requires --resume)') + # parser.add_argument('--no-plots', action='store_true', help='Disable generating plots') # Optional: Add if you want to disable plots args = parser.parse_args() - # Set device based on argument + # Set device explicitly args.device = torch.device(args.device if torch.cuda.is_available() else "cpu") - # If mixup is used, disable label smoothing effect by setting it to 0 - if args.mixup: - args.label_smoothing = 0.0 - print("Mixup enabled, label smoothing set to 0.") - - # Ensure log_dir is inside output_dir unless absolute path is given - if not os.path.isabs(args.log_dir): - args.log_dir = os.path.join(args.output_dir, args.tag if args.tag else '', args.log_dir) - - return args - if __name__ == '__main__': + # Need 'random' for seeding + import random args = parse_args() - # Ensure necessary directories exist based on final paths - os.makedirs(args.log_dir, exist_ok=True) main(args) From 30940ffe7d15e06098523c9b902f527e2516d917 Mon Sep 17 00:00:00 2001 From: us3r247 <81853860+us3r247@users.noreply.github.com> Date: Wed, 16 Apr 2025 03:04:31 +0530 Subject: [PATCH 5/7] fixed dataset path --- cifar_data.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cifar_data.py b/cifar_data.py index e3964f2..7b96466 100644 --- a/cifar_data.py +++ b/cifar_data.py @@ -3,11 +3,18 @@ from torchvision import datasets from torchvision.transforms import v2 from torch.utils.data import random_split,DataLoader +import os -# PATH_TO_CIFAR100 = "path/to/download/cifar100" TODO: change this while training +data_dir = os.path.join(os.getcwd(),"data") -PATH_TO_CIFAR100 = "/mnt/769EC2439EC1FB9D/vsc_projs/cifar100" +if os.path.isdir(os.path.join(data_dir,"cifar100")): + PATH_TO_CIFAR100 = os.path.join(data_dir,"cifar100") +else: + os.makedirs(os.path.join(data_dir,"cifar100")) + PATH_TO_CIFAR100 = os.path.join(data_dir,"cifar100") + +# PATH_TO_CIFAR100 = "/mnt/769EC2439EC1FB9D/vsc_projs/cifar100" transforms = v2.Compose([ v2.PILToTensor(), From a1d3f549aac9aba25e06f852cdb86ebfecb98b47 Mon Sep 17 00:00:00 2001 From: us3r247 <81853860+us3r247@users.noreply.github.com> Date: Wed, 16 Apr 2025 03:10:50 +0530 Subject: [PATCH 6/7] fixed data --- caltech_data.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/caltech_data.py b/caltech_data.py index 46b91a1..044c89a 100644 --- a/caltech_data.py +++ b/caltech_data.py @@ -9,9 +9,12 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.folder import default_loader -# PATH_TO_CALTECH256 = "path/to/download/caltech256" TODO: change this while trainign... -PATH_TO_CALTECH256 = "/mnt/769EC2439EC1FB9D/vsc_projs/caltech256" +data_dir = os.path.join(os.getcwd(),"data") + +PATH_TO_CALTECH256 = os.path.join(data_dir,"caltech256") + +# PATH_TO_CALTECH256 = "/mnt/769EC2439EC1FB9D/vsc_projs/caltech256" From 3fbdc100743f278b6c5494db473d2b330dae218a Mon Sep 17 00:00:00 2001 From: Yash-Agarwal-BITS Date: Wed, 16 Apr 2025 03:38:47 +0530 Subject: [PATCH 7/7] Revert "Update train.py" This reverts commit 8a56c1c809ed08440a1f14089ab7c24f07a21757. --- train.py | 862 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 447 insertions(+), 415 deletions(-) diff --git a/train.py b/train.py index f8089ee..4abd4be 100644 --- a/train.py +++ b/train.py @@ -2,7 +2,6 @@ import time import argparse import numpy as np -import pandas as pd # Added for reading logs import torch import torch.nn as nn import torch.nn.functional as F @@ -10,7 +9,9 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy import csv from tqdm import tqdm +import pandas as pd # Added for reading logs import matplotlib.pyplot as plt # Added for plotting +import seaborn as sns # Added for confusion matrix styling from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay # Added for CM # Import custom modules @@ -22,7 +23,7 @@ # Training Configuration TRAIN_CONFIG = { "batch_size": 128, - "epochs": 100, + "epochs": 100, # Reduced for faster demonstration if needed "learning_rate": 5e-4, "min_lr": 5e-6, "weight_decay": 0.05, @@ -31,7 +32,8 @@ "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), "output_dir": "output", "log_interval": 20, - "save_interval": 10, + "save_interval": 10, # Reduced for faster demonstration + "log_dir": "logs", # Added log dir to config } ##################################### @@ -53,8 +55,10 @@ def update(self, val, n=1): self.val = val self.sum += val * n self.count += n - # Prevent division by zero if count is 0 - self.avg = self.sum / self.count if self.count > 0 else 0 + if self.count > 0: + self.avg = self.sum / self.count + else: + self.avg = 0 def accuracy(output, target, topk=(1,)): @@ -62,6 +66,8 @@ def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) + if batch_size == 0: + return [torch.tensor(0.0) for _ in topk] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() @@ -81,13 +87,18 @@ def lr_lambda(current_step): progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) cosine_decay = 0.5 * (1 + np.cos(np.pi * progress)) - # Ensure the final learning rate is at least min_lr + # Ensure the final learning rate doesn't go below min_lr return max(min_lr / base_lr, cosine_decay) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): """Save model checkpoint""" + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename), exist_ok=True) + print(f"Created directory: {os.path.dirname(filename)}") + state = { 'epoch': epoch, 'model': model.state_dict(), @@ -100,51 +111,55 @@ def save_checkpoint(model, optimizer, scheduler, epoch, accuracy, filename): print(f"Checkpoint saved to {filename}") def save_pretrain_checkpoint(model, dataset_name, output_dir='checkpoints'): + """Save model checkpoint after pretraining""" os.makedirs(output_dir, exist_ok=True) path = os.path.join(output_dir, f'{dataset_name}_checkpoint.pth') + # Save only the model state dict for pretraining checkpoints torch.save(model.state_dict(), path) print(f"Saved pretraining checkpoint for {dataset_name} at: {path}") -def load_checkpoint(model, optimizer=None, scheduler=None, filename=None): +def load_checkpoint(model, optimizer=None, scheduler=None, filename=None, load_optimizer_scheduler=True): """Load checkpoint from file""" - if not os.path.isfile(filename): + if not filename or not os.path.isfile(filename): print(f"No checkpoint found at {filename}") return 0, 0.0 print(f"Loading checkpoint from {filename}") checkpoint = torch.load(filename, map_location='cpu') - # Adjust for potential DataParallel prefix 'module.' - state_dict = checkpoint['model'] - new_state_dict = {} - for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k # remove `module.` prefix - new_state_dict[name] = v - model.load_state_dict(new_state_dict) - - - if optimizer is not None and 'optimizer' in checkpoint: - try: - optimizer.load_state_dict(checkpoint['optimizer']) - except ValueError as e: - print(f"Could not load optimizer state: {e}. This might happen if model structure changed.") + # Handle both full checkpoints and state_dict-only checkpoints + if 'model' in checkpoint: + model.load_state_dict(checkpoint['model']) + else: + # Assume it's just a state_dict + model.load_state_dict(checkpoint) + # If only state_dict loaded, cannot resume optimizer/scheduler/epoch + print("Loaded model state_dict only. Cannot resume optimizer, scheduler, or epoch.") + return 0, checkpoint.get('accuracy', 0.0) # Return 0 epoch, try to get accuracy + epoch = checkpoint.get('epoch', 0) + accuracy = checkpoint.get('accuracy', 0.0) - if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: - try: - scheduler.load_state_dict(checkpoint['scheduler']) - except KeyError as e: - print(f"Could not load scheduler state: {e}. This might happen if scheduler type changed.") + if load_optimizer_scheduler: + if optimizer is not None and 'optimizer' in checkpoint: + try: + optimizer.load_state_dict(checkpoint['optimizer']) + except Exception as e: + print(f"Could not load optimizer state: {e}. Continuing without loading optimizer.") - start_epoch = checkpoint.get('epoch', 0) - accuracy = checkpoint.get('accuracy', 0.0) + if scheduler is not None and 'scheduler' in checkpoint and checkpoint['scheduler'] is not None: + try: + scheduler.load_state_dict(checkpoint['scheduler']) + except Exception as e: + print(f"Could not load scheduler state: {e}. Continuing without loading scheduler.") - print(f"Checkpoint loaded. Resuming from epoch {start_epoch} with validation accuracy {accuracy:.2f}%") - return start_epoch, accuracy + else: + print("Skipping loading optimizer and scheduler state.") + return epoch, accuracy def create_optimizer(model, lr, weight_decay): """Create optimizer for model""" @@ -170,6 +185,7 @@ def create_scheduler(optimizer, num_epochs, steps_per_epoch, base_lr, warmup_epo def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir="logs"): + """Logs training and validation metrics to a CSV file.""" os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") write_header = not os.path.exists(log_file) @@ -183,181 +199,212 @@ def log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_los "val_loss", "val_acc1", "val_acc5" ]) writer.writerow([ - epoch + 1, - train_loss if train_loss is not None else 'N/A', # Handle None cases - train_acc1 if train_acc1 is not None else 'N/A', - train_acc5 if train_acc5 is not None else 'N/A', - val_loss, val_acc1, val_acc5 + epoch + 1, # Log 1-based epoch + f"{train_loss:.4f}" if train_loss is not None else "N/A", + f"{train_acc1:.2f}" if train_acc1 is not None else "N/A", + f"{train_acc5:.2f}" if train_acc5 is not None else "N/A", + f"{val_loss:.4f}" if val_loss is not None else "N/A", + f"{val_acc1:.2f}" if val_acc1 is not None else "N/A", + f"{val_acc5:.2f}" if val_acc5 is not None else "N/A" ]) def replace_head(model, num_classes): - in_features = model.head.in_features # Get input features from existing head - model.head = nn.Linear(in_features, num_classes) - print(f"Replaced model head for {num_classes} classes.") + """Replaces the classification head of the model.""" + in_features = 0 + if hasattr(model, 'head') and hasattr(model.head, 'in_features'): + in_features = model.head.in_features + elif hasattr(model, 'fc') and hasattr(model.fc, 'in_features'): # common alternative name + in_features = model.fc.in_features + elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): # Another common name + in_features = model.classifier.in_features + elif hasattr(model, 'num_features'): # Timm models often have this property + in_features = model.num_features + else: + raise AttributeError("Cannot determine the input features of the model's classification head. Tried 'head', 'fc', 'classifier'.") + + # Replace the head + if hasattr(model, 'head'): + model.head = nn.Linear(in_features, num_classes) + elif hasattr(model, 'fc'): + model.fc = nn.Linear(in_features, num_classes) + elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear): + model.classifier = nn.Linear(in_features, num_classes) + else: + # Fallback for models where head is not explicitly named 'head', 'fc' or 'classifier' + # This might need adjustment based on the specific architecture if get_model returns something unusual + print("Warning: Replacing head using a generic approach based on Timm's num_features. Ensure this is correct for the model.") + model.head = nn.Linear(in_features, num_classes) # Assume we can add a 'head' attribute + + print(f"Replaced model head with a new one for {num_classes} classes.") return model ##################################### -# Plotting Utilities # +# Plotting Functions ##################################### def plot_metrics(log_file, output_dir, dataset_name): - """Plots training/validation loss and accuracy from log file.""" + """Plots loss and accuracy curves from a log file.""" + if not os.path.exists(log_file): + print(f"Log file not found: {log_file}. Skipping plotting.") + return + try: df = pd.read_csv(log_file) - df.replace('N/A', np.nan, inplace=True) # Replace 'N/A' with NaN - df = df.astype({ # Convert relevant columns to numeric, errors='coerce' turns failures into NaN - 'train_loss': float, 'train_acc1': float, 'train_acc5': float, - 'val_loss': float, 'val_acc1': float, 'val_acc5': float - }, errors='coerce') - - - epochs = df['epoch'] - - plt.style.use('seaborn-v0_8-grid') # Use a nice style - fig, axes = plt.subplots(1, 2, figsize=(15, 5)) - - # Plot Loss - axes[0].plot(epochs, df['train_loss'], 'bo-', label='Train Loss') - axes[0].plot(epochs, df['val_loss'], 'ro-', label='Validation Loss') - axes[0].set_title(f'{dataset_name} - Loss vs. Epochs') - axes[0].set_xlabel('Epochs') - axes[0].set_ylabel('Loss') - axes[0].legend() - axes[0].grid(True) - - # Plot Accuracy - if df['train_acc1'].notna().any(): # Only plot train acc if data exists - axes[1].plot(epochs, df['train_acc1'], 'bo-', label='Train Accuracy@1') - axes[1].plot(epochs, df['val_acc1'], 'ro-', label='Validation Accuracy@1') - # Optionally plot Acc@5 - # if df['train_acc5'].notna().any(): - # axes[1].plot(epochs, df['train_acc5'], 'b--', label='Train Accuracy@5') - # axes[1].plot(epochs, df['val_acc5'], 'r--', label='Validation Accuracy@5') - axes[1].set_title(f'{dataset_name} - Accuracy vs. Epochs') - axes[1].set_xlabel('Epochs') - axes[1].set_ylabel('Accuracy (%)') - axes[1].legend() - axes[1].grid(True) - - plt.tight_layout() - plot_filename = os.path.join(output_dir, f'{dataset_name}_metrics_plot.png') - plt.savefig(plot_filename) - print(f"Metrics plot saved to {plot_filename}") - plt.close(fig) # Close the figure to free memory - - except FileNotFoundError: - print(f"Log file not found at {log_file}, skipping metrics plot.") except Exception as e: - print(f"Could not generate metrics plot: {e}") + print(f"Error reading log file {log_file}: {e}. Skipping plotting.") + return + if df.empty: + print(f"Log file {log_file} is empty. Skipping plotting.") + return -def plot_confusion_matrix(all_preds, all_targets, class_names, output_dir, dataset_name): - """Plots the confusion matrix.""" - if not all_preds or not all_targets: - print("No predictions or targets found, skipping confusion matrix.") + plt.style.use('seaborn-v0_8-grid') # Use a nice style + fig, ax1 = plt.subplots(figsize=(12, 6)) + + # Plot Loss + color = 'tab:red' + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss', color=color) + if 'train_loss' in df.columns and pd.to_numeric(df['train_loss'], errors='coerce').notna().any(): + ax1.plot(df['epoch'], pd.to_numeric(df['train_loss'], errors='coerce'), label='Train Loss', color=color, linestyle='--') + if 'val_loss' in df.columns and pd.to_numeric(df['val_loss'], errors='coerce').notna().any(): + ax1.plot(df['epoch'], pd.to_numeric(df['val_loss'], errors='coerce'), label='Validation Loss', color=color) + ax1.tick_params(axis='y', labelcolor=color) + ax1.legend(loc='upper left') + + # Plot Accuracy + ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis + color = 'tab:blue' + ax2.set_ylabel('Accuracy (%)', color=color) + if 'train_acc1' in df.columns and pd.to_numeric(df['train_acc1'], errors='coerce').notna().any(): + ax2.plot(df['epoch'], pd.to_numeric(df['train_acc1'], errors='coerce'), label='Train Acc@1', color=color, linestyle='--') + if 'val_acc1' in df.columns and pd.to_numeric(df['val_acc1'], errors='coerce').notna().any(): + ax2.plot(df['epoch'], pd.to_numeric(df['val_acc1'], errors='coerce'), label='Validation Acc@1', color=color) + ax2.tick_params(axis='y', labelcolor=color) + ax2.legend(loc='lower left') + + plt.title(f'{dataset_name} - Training & Validation Metrics') + fig.tight_layout() # otherwise the right y-label is slightly clipped + + # Save plot + plot_filename = os.path.join(output_dir, f"{dataset_name}_metrics_plot.png") + plt.savefig(plot_filename) + print(f"Metrics plot saved to {plot_filename}") + plt.close(fig) # Close the figure to free memory + +def plot_confusion_matrix(all_preds, all_targets, num_classes, output_dir, dataset_name): + """Computes and plots the confusion matrix.""" + if all_preds is None or all_targets is None: + print(f"No prediction data available for {dataset_name}. Skipping confusion matrix.") return - try: - cm = confusion_matrix(all_targets, all_preds) - # Adjust figure size based on number of classes - figsize = max(8, len(class_names) // 6) # Heuristic for figure size - fig, ax = plt.subplots(figsize=(figsize, figsize)) + cm = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes)) - disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names) - disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format='d') + # Determine figure size based on number of classes + figsize = max(8, num_classes // 5) # Adjust divisor as needed - ax.set_title(f'{dataset_name} - Confusion Matrix') - plt.tight_layout() # Adjust layout to prevent overlap - cm_filename = os.path.join(output_dir, f'{dataset_name}_confusion_matrix.png') - plt.savefig(cm_filename) - print(f"Confusion matrix saved to {cm_filename}") - plt.close(fig) # Close the figure + fig, ax = plt.subplots(figsize=(figsize, figsize)) + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=np.arange(num_classes)) + + # Determine whether to show values based on matrix size + show_values = num_classes <= 30 # Only show values for smaller matrices + + disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format='d' if show_values else None) # Only show numbers if show_values is True + + plt.title(f'{dataset_name} - Confusion Matrix') + plt.tight_layout() + + # Save plot + cm_filename = os.path.join(output_dir, f"{dataset_name}_confusion_matrix.png") + plt.savefig(cm_filename) + print(f"Confusion matrix saved to {cm_filename}") + plt.close(fig) # Close the figure - except Exception as e: - print(f"Could not generate confusion matrix: {e}") ##################################### -# Pretraining Function # +# Pretraining Function ##################################### def pretrain_on_dataset(model, train_loader, val_loader, num_classes, args, dataset_name, output_dir, log_dir): - model = replace_head(model, num_classes) + """Pretrains the model on a given dataset (CIFAR or Caltech).""" + print(f"\n=== Pretraining Stage: {dataset_name} ===") + model = replace_head(model, num_classes) # Ensure head matches dataset model.to(args.device) - # Handle DataParallel if multiple GPUs are available - if torch.cuda.device_count() > 1: - print(f"Using {torch.cuda.device_count()} GPUs for pretraining on {dataset_name}") - model = nn.DataParallel(model) - optimizer = create_optimizer(model, args.lr, args.weight_decay) - # Ensure steps_per_epoch is calculated correctly steps_per_epoch = len(train_loader) - if steps_per_epoch == 0: - print(f"Warning: train_loader for {dataset_name} is empty!") - return model # Cannot train with empty loader - scheduler = create_scheduler( optimizer, - num_epochs=args.epochs, + num_epochs=args.epochs, # Use main epoch count for pretraining? Or specific pretrain epochs? Using main for now. steps_per_epoch=steps_per_epoch, - base_lr=args.lr, # Pass base LR + base_lr=args.lr, warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) + # Use label smoothing for pretraining as well criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) - print(f"\n=== Pretraining on {dataset_name} ({num_classes} classes) ===") best_acc = 0.0 - pretrain_output_dir = os.path.join(output_dir, f"pretrain_{dataset_name}") - os.makedirs(pretrain_output_dir, exist_ok=True) - log_file_path = os.path.join(log_dir, f"{dataset_name}_log.csv") - - # --- Training Loop --- - for epoch in range(args.epochs): + for epoch in range(args.epochs): # Use same number of epochs as main training? + print(f"\n--- {dataset_name} Epoch {epoch+1}/{args.epochs} ---") train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.epochs, args.device, dataset_name=f"Pretrain {dataset_name}", mixup_fn=None # No mixup for pretrain + model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args=args ) - val_loss, val_acc1, val_acc5, _, _ = validate( # Discard preds/targets here - model, val_loader, criterion, args.device, dataset_name=f"Pretrain {dataset_name}" + val_loss, val_acc1, val_acc5, _, _ = validate( # Get metrics only during training loop + model, val_loader, criterion, args.device, return_preds_targets=False ) - best_acc = max(val_acc1, best_acc) - print(f"[{dataset_name}] Epoch {epoch+1}: Val Acc@1={val_acc1:.2f}% | Best Val Acc@1={best_acc:.2f}%") - - # Log metrics + # Log metrics for this pretraining stage log_metrics(dataset_name, epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) - # Save checkpoint (optional, can save best only) - # save_checkpoint(model.module if isinstance(model, nn.DataParallel) else model, optimizer, scheduler, epoch + 1, val_acc1, - # os.path.join(pretrain_output_dir, f'ckpt_epoch_{epoch+1}.pth')) + # Save best model based on validation accuracy for this stage + if val_acc1 > best_acc: + best_acc = val_acc1 + save_pretrain_checkpoint(model, dataset_name, output_dir=os.path.join(output_dir, 'checkpoints')) + print(f"[{dataset_name}] New best accuracy: {best_acc:.2f}% (Epoch {epoch+1}). Checkpoint saved.") + else: + print(f"[{dataset_name}] Epoch {epoch+1}: Acc@1={val_acc1:.2f}% | Best={best_acc:.2f}%") + - # --- Save Final/Best Pretrained Model --- - # Retrieve the actual model from DataParallel wrapper if necessary - final_model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict() - save_pretrain_checkpoint(final_model_state, dataset_name, output_dir=pretrain_output_dir) # Pass state_dict directly + print(f"--- Finished Pretraining on {dataset_name} ---") + + # --- Plotting and Final Validation for this Stage --- + log_file = os.path.join(log_dir, f"{dataset_name}_log.csv") + plot_metrics(log_file, output_dir, dataset_name) + + # Load the best checkpoint for this stage for final validation and CM + best_checkpoint_path = os.path.join(output_dir, 'checkpoints', f'{dataset_name}_checkpoint.pth') + if os.path.exists(best_checkpoint_path): + print(f"Loading best {dataset_name} model for final validation...") + # Load only state_dict, don't need optimizer/scheduler here + load_checkpoint(model, filename=best_checkpoint_path, load_optimizer_scheduler=False) + else: + print(f"Warning: Best checkpoint {best_checkpoint_path} not found. Using model from last epoch for validation.") + + + print(f"Running final validation on {dataset_name} to generate Confusion Matrix...") + final_val_loss, final_val_acc1, final_val_acc5, all_preds, all_targets = validate( + model, val_loader, criterion, args.device, return_preds_targets=True + ) + print(f"Final {dataset_name} Validation: Loss={final_val_loss:.4f}, Acc@1={final_val_acc1:.2f}%, Acc@5={final_val_acc5:.2f}%") - # --- Generate Plots after Pretraining --- - print(f"\n--- Generating plots for {dataset_name} pretraining ---") - # Plot Loss/Accuracy Curves - plot_metrics(log_file_path, pretrain_output_dir, dataset_name) + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), num_classes, output_dir, dataset_name) + else: + print(f"Could not generate confusion matrix for {dataset_name} due to missing prediction data.") - # Generate Confusion Matrix (requires one last validation run) - print(f"Running final validation on {dataset_name} for Confusion Matrix...") - _, _, _, final_preds, final_targets = validate(model, val_loader, criterion, args.device, dataset_name=f"Final {dataset_name} Val") - class_names = [str(i) for i in range(num_classes)] # Generic class names - plot_confusion_matrix(final_preds, final_targets, class_names, pretrain_output_dir, dataset_name) - # Return the base model (without DataParallel wrapper if it was used) - return model.module if isinstance(model, nn.DataParallel) else model + # Important: Return the model (potentially loaded with best weights) + return model ##################################### # Training and Evaluation Functions ##################################### -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, num_epochs, device, dataset_name="Train", mixup_fn=None): +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, device, args, mixup_fn=None): """Train model for one epoch""" model.train() @@ -365,73 +412,58 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, top1 = AverageMeter() top5 = AverageMeter() - # Initialize tqdm progress bar - pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [{dataset_name}]", - leave=True, ncols=100, unit="batch") - - steps_per_epoch = len(train_loader) # Get steps per epoch + steps_per_epoch = len(train_loader) + pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]", + leave=False, ncols=100, unit="batch") # Changed leave to False for cleaner nested loops - for i, (images, target) in enumerate(pbar): + for batch_idx, (images, target) in enumerate(pbar): # Move data to device images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) # Apply mixup or cutmix if available - is_soft_target = False if mixup_fn is not None: images, target = mixup_fn(images, target) - is_soft_target = True # Target is now soft # Forward pass output = model(images) - # Calculate loss based on target type - loss = criterion(output, target) + # Handle cases where mixup changes target format + if mixup_fn is not None and len(target.shape) > 1: + loss = criterion(output, target) # SoftTargetCrossEntropy handles smoothed labels + # Accuracy calculation is ambiguous with mixup, often skipped or calculated differently + acc1, acc5 = [torch.tensor(0.0), torch.tensor(0.0)] # Placeholder + else: + loss = criterion(output, target) # Standard CE or LabelSmoothing + acc1, acc5 = accuracy(output, target, topk=(1, 5)) - # Measure accuracy and record loss - # Accuracy calculation is only valid if not using soft targets (mixup/cutmix) - batch_top1_avg = None - batch_top5_avg = None - if not is_soft_target: - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - top1.update(acc1[0].item(), images.size(0)) - top5.update(acc5[0].item(), images.size(0)) - batch_top1_avg = top1.avg # Use running average for display - batch_top5_avg = top5.avg # Update meters losses.update(loss.item(), images.size(0)) + if mixup_fn is None: # Only update accuracy if not using mixup + top1.update(acc1[0].item(), images.size(0)) + top5.update(acc5[0].item(), images.size(0)) # Backward pass and optimize optimizer.zero_grad() loss.backward() - # Gradient clipping (optional but often helpful) - # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() - # Step the scheduler *after* the optimizer step - # The LR scheduler expects step increments, adjust based on total steps - current_step = epoch * steps_per_epoch + i - scheduler.step(current_step) # Pass current step for LambdaLR-like schedulers - + # Adjust learning rate based on step, not epoch (important for cosine schedule with warmup) + scheduler.step() # Update progress bar - lr = optimizer.param_groups[0]['lr'] # More reliable way to get current LR - postfix_dict = { + lr = scheduler.get_last_lr()[0] + pbar.set_postfix({ 'Loss': f"{losses.avg:.4f}", + 'Acc@1': f"{top1.avg:.2f}%" if mixup_fn is None else "N/A", 'LR': f"{lr:.6f}" - } - if batch_top1_avg is not None: - postfix_dict['Acc@1'] = f"{batch_top1_avg:.2f}%" - # if batch_top5_avg is not None: - # postfix_dict['Acc@5'] = f"{batch_top5_avg:.2f}%" - pbar.set_postfix(postfix_dict) - - # Return epoch averages (handle case where accuracy wasn't measured) - return losses.avg, top1.avg if top1.count > 0 else None, top5.avg if top5.count > 0 else None + }) + # Return average metrics for the epoch + return losses.avg, top1.avg, top5.avg -def validate(model, val_loader, criterion, device, dataset_name="Validation"): +def validate(model, val_loader, criterion, device, return_preds_targets=False): """Evaluate model on validation set""" model.eval() @@ -439,45 +471,46 @@ def validate(model, val_loader, criterion, device, dataset_name="Validation"): top1 = AverageMeter() top5 = AverageMeter() - all_preds = [] - all_targets = [] + all_preds_list = [] + all_targets_list = [] - # Initialize tqdm progress bar - pbar = tqdm(val_loader, desc=f"[{dataset_name}]", leave=False, ncols=100, unit="batch") + pbar = tqdm(val_loader, desc="Validation", leave=False, ncols=100, unit="batch") # Changed leave to False with torch.no_grad(): for images, target in pbar: - # Move data to device images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - # Forward pass output = model(images) loss = criterion(output, target) - # Measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) - # Collect predictions for Confusion Matrix - _, predicted_indices = torch.max(output.data, 1) - all_preds.extend(predicted_indices.cpu().numpy()) - all_targets.extend(target.cpu().numpy()) - - - # Update meters losses.update(loss.item(), images.size(0)) top1.update(acc1[0].item(), images.size(0)) top5.update(acc5[0].item(), images.size(0)) - # Update progress bar + if return_preds_targets: + preds = torch.argmax(output, dim=1) + all_preds_list.append(preds.cpu()) # Move to CPU immediately + all_targets_list.append(target.cpu()) # Move to CPU immediately + + pbar.set_postfix({ 'Loss': f"{losses.avg:.4f}", 'Acc@1': f"{top1.avg:.2f}%", 'Acc@5': f"{top5.avg:.2f}%" }) - print(f"* [{dataset_name}]: Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% Loss {losses.avg:.4f}") - # Return averages AND the collected predictions/targets + all_preds = None + all_targets = None + if return_preds_targets and len(all_preds_list) > 0: + all_preds = torch.cat(all_preds_list) + all_targets = torch.cat(all_targets_list) + + + # No need to print here if called during training loop, will be printed in main loop + # If called standalone (e.g., for final CM), the calling function should print return losses.avg, top1.avg, top5.avg, all_preds, all_targets ##################################### @@ -488,287 +521,286 @@ def main(args): # Set random seed for reproducibility torch.manual_seed(args.seed) np.random.seed(args.seed) - random.seed(args.seed) # Add random seed for python's random module if torch.cuda.is_available(): - torch.cuda.manual_seed_all(args.seed) # Seed all GPUs - # These can sometimes slow down training or cause issues, enable if needed - # torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True # Usually speeds up training - - # Create output directory - output_dir = os.path.join(args.output_dir, args.tag if args.tag else time.strftime("%Y%m%d-%H%M%S")) - os.makedirs(output_dir, exist_ok=True) - log_dir = os.path.join(args.log_dir, args.tag if args.tag else os.path.basename(output_dir)) # Log dir specific to this run + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for speed + + # --- Setup Directories --- + base_output_dir = args.output_dir + experiment_output_dir = os.path.join(base_output_dir, args.tag) if args.tag else base_output_dir + log_dir = args.log_dir # Use dedicated log dir + os.makedirs(experiment_output_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + checkpoint_dir = os.path.join(experiment_output_dir, 'checkpoints') # Subdir for checkpoints + os.makedirs(checkpoint_dir, exist_ok=True) + + print(f"Output Directory: {experiment_output_dir}") + print(f"Log Directory: {log_dir}") + print(f"Checkpoint Directory: {checkpoint_dir}") - print(f"Output directory: {output_dir}") - print(f"Log directory: {log_dir}") - print(f"Using device: {args.device}") - # Create model - start with ImageNet classes (1000) or a base number - # The head will be replaced before each training stage. + # --- Create Model --- print("Initializing model...") - model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=1000) # Start with 1000 classes - - # === Pretraining Stage: CIFAR === - if args.pretrain_cifar and not args.skip_pretrain: - print("\n--- Starting CIFAR-100 Pretraining Stage ---") - print("Creating dataloaders for cifar100...") - cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset( - model, cifar_train_loader, cifar_val_loader, - num_classes=100, args=args, dataset_name='cifar100', - output_dir=output_dir, log_dir=log_dir - ) - print("--- CIFAR-100 Pretraining Stage Complete ---") - - - # === Pretraining Stage: Caltech === - if args.pretrain_caltech and not args.skip_pretrain: - print("\n--- Starting Caltech-256 Pretraining Stage ---") - print("Creating dataloaders for caltech256...") - caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) - caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) - model = pretrain_on_dataset( - model, caltech_train_loader, caltech_val_loader, - num_classes=257, args=args, dataset_name='caltech256', - output_dir=output_dir, log_dir=log_dir - ) - print("--- Caltech-256 Pretraining Stage Complete ---") + model = get_model(model_name='swin_t', efficient=args.efficient) + print(f"Model: swin_t (Efficient: {args.efficient})") + + + # --- Pretraining Stages --- + if not args.skip_pretrain: + if args.pretrain_cifar: + print("\n>>> Starting CIFAR-100 Pretraining Stage <<<") + cifar_train_loader = get_cifar_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + cifar_val_loader = get_cifar_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, cifar_train_loader, cifar_val_loader, + num_classes=100, args=args, dataset_name='cifar100', + output_dir=experiment_output_dir, log_dir=log_dir) + print("\n>>> Finished CIFAR-100 Pretraining Stage <<<") + + + if args.pretrain_caltech: + print("\n>>> Starting Caltech-256 Pretraining Stage <<<") + # If CIFAR pretraining happened, the model already has a head for 100 classes. + # If not, it has the original 1000 class head. `pretrain_on_dataset` handles replacement. + caltech_train_loader = get_caltech_train_loader(args.batch_size, num_workers=args.workers, shuffle=True) + caltech_val_loader = get_caltech_val_loader(args.batch_size, num_workers=args.workers, shuffle=False) + model = pretrain_on_dataset(model, caltech_train_loader, caltech_val_loader, + num_classes=257, args=args, dataset_name='caltech256', + output_dir=experiment_output_dir, log_dir=log_dir) + print("\n>>> Finished Caltech-256 Pretraining Stage <<<") + else: + print("Skipping all pretraining stages as requested.") + - # === Final Training Stage (Tiny ImageNet) === - print("\n--- Starting Final Training Stage (Tiny ImageNet) ---") - final_num_classes = 200 # Tiny ImageNet has 200 classes - print(f"Creating dataloaders for Tiny ImageNet ({final_num_classes} classes)...") + # --- Final Training Stage: Tiny ImageNet --- + print("\n>>> Starting Final Training Stage: Tiny ImageNet <<<") + print("Creating dataloaders for Tiny ImageNet...") train_loader, val_loader, mixup_fn = get_loaders( batch_size=args.batch_size, num_workers=args.workers, - img_size=MODEL_CONFIG["img_size"], # Make sure this matches Swin-T expectation + img_size=MODEL_CONFIG["img_size"], use_mixup=args.mixup ) - # Replace head for the final dataset - model = replace_head(model, num_classes=final_num_classes) - model.to(args.device) - - # Handle DataParallel if multiple GPUs are available for the final stage - if torch.cuda.device_count() > 1 and not args.evaluate: # Don't wrap if only evaluating - print(f"Using {torch.cuda.device_count()} GPUs for final training stage.") - model = nn.DataParallel(model) - + print("Replacing model head for Tiny ImageNet (200 classes)...") + model = replace_head(model, num_classes=200) + model = model.to(args.device) # Print model information num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Final Model - Number of trainable parameters: {num_params:,}") - - # Create optimizer - optimizer = create_optimizer(model.module if isinstance(model, nn.DataParallel) else model, args.lr, args.weight_decay) - - - # Create scheduler - steps_per_epoch_final = len(train_loader) - if steps_per_epoch_final == 0: - print("Error: Final train loader is empty!") - return + print(f"Number of trainable parameters (final stage): {num_params:,}") + # Create optimizer and scheduler for the final stage + optimizer = create_optimizer(model, args.lr, args.weight_decay) + steps_per_epoch_main = len(train_loader) scheduler = create_scheduler( optimizer, num_epochs=args.epochs, - steps_per_epoch=steps_per_epoch_final, - base_lr=args.lr, # Pass base LR + steps_per_epoch=steps_per_epoch_main, + base_lr=args.lr, warmup_epochs=args.warmup_epochs, min_lr=args.min_lr ) - # Create loss function - if args.mixup and mixup_fn is not None: - print("Using Mixup/CutMix augmentation with SoftTargetCrossEntropy loss.") + # Create loss function for the final stage + if args.mixup: criterion = SoftTargetCrossEntropy().to(args.device) + print("Using Mixup/Cutmix augmentation with SoftTargetCrossEntropy loss.") else: - print(f"Using Label Smoothing Cross Entropy loss with smoothing={args.label_smoothing}.") criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) + print(f"Using Label Smoothing Cross Entropy loss (smoothing={args.label_smoothing}).") - # Optionally resume from checkpoint + # Optionally resume from a final stage checkpoint start_epoch = 0 best_acc = 0.0 if args.resume: - # Ensure the model head matches the checkpoint's expected classes before loading - # Note: This might require knowing the number of classes the checkpoint was saved with. - # If resuming a final stage checkpoint, the head should already be correct (200). - # If resuming a pretraining checkpoint, this logic needs adjustment. - # For simplicity, assume resume is for the final stage. - model_to_load = model.module if isinstance(model, nn.DataParallel) else model - start_epoch, best_acc = load_checkpoint(model_to_load, optimizer, scheduler, args.resume) - # Sync epoch for DataParallel case? Usually handled inside load_checkpoint if needed. - - # Evaluation only + resume_path = args.resume if os.path.isabs(args.resume) else os.path.join(checkpoint_dir, args.resume) + if os.path.isfile(resume_path): + print(f"Attempting to resume final stage training from: {resume_path}") + # Load optimizer and scheduler state when resuming main training + start_epoch, best_acc = load_checkpoint(model, optimizer, scheduler, resume_path, load_optimizer_scheduler=True) + print(f"Resumed final stage from epoch {start_epoch}. Previous best accuracy: {best_acc:.2f}%") + start_epoch = start_epoch # Checkpoint saves epoch+1, so start from the returned value + else: + print(f"Resume checkpoint not found at '{resume_path}'. Starting final training from scratch.") + + # Evaluation only mode for the final model if args.evaluate: - print("\n--- Running Evaluation Only ---") - if not args.resume: - print("Warning: Evaluating without loading a checkpoint (`--resume` not specified). Using initial model weights.") - _, val_acc1, _, final_preds, final_targets = validate(model, val_loader, criterion, args.device, dataset_name="Evaluation") - print(f"Evaluation Accuracy@1: {val_acc1:.3f}%") - # Generate plots for evaluation run - class_names = [str(i) for i in range(final_num_classes)] # Generic class names - eval_output_dir = os.path.join(output_dir, "evaluation") - os.makedirs(eval_output_dir, exist_ok=True) - plot_confusion_matrix(final_preds, final_targets, class_names, eval_output_dir, "tinyimagenet_eval") - # Cannot plot loss/acc curves without training history - return + print("--- Running Evaluation Only Mode ---") + eval_checkpoint_path = args.resume if args.resume else os.path.join(experiment_output_dir, 'model_best.pth') + if os.path.isfile(eval_checkpoint_path): + print(f"Loading model from: {eval_checkpoint_path} for evaluation...") + # Don't load optimizer/scheduler for evaluation + load_checkpoint(model, filename=eval_checkpoint_path, load_optimizer_scheduler=False) + print("Running validation...") + val_loss, val_acc1, val_acc5, all_preds, all_targets = validate( + model, val_loader, criterion, args.device, return_preds_targets=True + ) + print(f"\nEvaluation Results (Tiny ImageNet):") + print(f" Loss: {val_loss:.4f}") + print(f" Acc@1: {val_acc1:.2f}%") + print(f" Acc@5: {val_acc5:.2f}%") + + # Plot confusion matrix for evaluation + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_eval") + else: + print("Could not generate confusion matrix due to missing prediction data.") + + else: + print(f"Evaluation checkpoint '{eval_checkpoint_path}' not found. Cannot evaluate.") + return # Exit after evaluation - # Print training configuration - print(f"\nStarting final training for {args.epochs} epochs (from epoch {start_epoch})") - print(f"Batch size: {args.batch_size}") - print(f"Initial Learning rate: {args.lr}") - print(f"Minimum Learning rate: {args.min_lr}") - print(f"Weight decay: {args.weight_decay}") - print(f"Using mixup: {args.mixup and mixup_fn is not None}") - print(f"Label smoothing: {args.label_smoothing if not (args.mixup and mixup_fn is not None) else 'N/A (using SoftTargetCE)'}") - # --- Final Training loop --- - log_file_path_main = os.path.join(log_dir, "main_log.csv") + # --- Main Training Loop --- + print(f"\n--- Starting Final Training Loop (Tiny ImageNet) for {args.epochs - start_epoch} epochs ---") + print(f"Batch size: {args.batch_size}, Initial LR: {args.lr}, Weight Decay: {args.weight_decay}, Mixup: {args.mixup}") for epoch in range(start_epoch, args.epochs): epoch_start = time.time() + print(f"\n--- Tiny ImageNet Epoch {epoch+1}/{args.epochs} ---") - # Train for one epoch + # Train train_loss, train_acc1, train_acc5 = train_one_epoch( - model, train_loader, criterion, optimizer, scheduler, epoch, args.epochs, args.device, dataset_name="TinyImageNet Train", mixup_fn=mixup_fn) + model, train_loader, criterion, optimizer, scheduler, epoch, args.device, args, mixup_fn + ) - # Evaluate on validation set - val_loss, val_acc1, val_acc5, _, _ = validate( # Discard preds/targets during epoch validation - model, val_loader, criterion, args.device, dataset_name="TinyImageNet Val") + # Validate + val_loss, val_acc1, val_acc5, _, _ = validate( # Don't need preds/targets here + model, val_loader, criterion, args.device, return_preds_targets=False + ) - # Calculate epoch time epoch_time = time.time() - epoch_start - # Check if current model is the best + # Check if current epoch is best is_best = val_acc1 > best_acc if is_best: + old_best = best_acc best_acc = val_acc1 - print(f"** New Best Val Acc@1: {best_acc:.3f}% **") - - # Log metrics to CSV - # Handle None for train_acc if mixup was used - log_metrics("main", epoch, - train_loss, train_acc1, train_acc5, - val_loss, val_acc1, val_acc5, - log_dir=log_dir) - - - # Print epoch summary - print(f"--- Epoch {epoch+1}/{args.epochs} Summary ---") - print(f" Time: {epoch_time:.2f}s") - train_acc1_str = f"{train_acc1:.2f}%" if train_acc1 is not None else "N/A (Mixup)" - train_acc5_str = f"{train_acc5:.2f}%" if train_acc5 is not None else "N/A (Mixup)" - print(f" Train: Loss {train_loss:.4f}, Acc@1 {train_acc1_str}, Acc@5 {train_acc5_str}") - print(f" Valid: Loss {val_loss:.4f}, Acc@1 {val_acc1:.2f}%, Acc@5 {val_acc5:.2f}%") - print(f" Best Valid Acc@1 so far: {best_acc:.2f}%") - print("-" * (len(f"--- Epoch {epoch+1}/{args.epochs} Summary ---"))) # Divider + print(f"*** New Best Accuracy: {best_acc:.2f}% (Improved from {old_best:.2f}%) ***") + else: + print(f"Validation Acc@1: {val_acc1:.2f}% (Best: {best_acc:.2f}%)") - # Retrieve the actual model state dict, handling DataParallel - model_state_to_save = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict() - + # Log metrics to CSV for the main training stage + log_metrics("main", epoch, train_loss, train_acc1, train_acc5, val_loss, val_acc1, val_acc5, log_dir=log_dir) # Save checkpoint periodically if (epoch + 1) % args.save_interval == 0: - save_checkpoint( - model_state_to_save, optimizer, scheduler, epoch + 1, val_acc1, - os.path.join(output_dir, f'checkpoint_epoch{epoch+1}.pth') - ) + save_checkpoint( + model, optimizer, scheduler, epoch + 1, val_acc1, + os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch+1}.pth') + ) - # Always save the best model based on validation accuracy + # Always save the best model if is_best: - save_checkpoint( - model_state_to_save, optimizer, scheduler, epoch + 1, best_acc, # Save best_acc here - os.path.join(output_dir, 'model_best.pth') - ) + save_checkpoint( + model, optimizer, scheduler, epoch + 1, best_acc, # Save best_acc here + os.path.join(experiment_output_dir, 'model_best.pth') # Save best model in parent dir + ) - print(f"\nTraining complete. Best validation accuracy: {best_acc:.2f}%") + # Print epoch summary + print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s | LR: {scheduler.get_last_lr()[0]:.6f}") + print(f" Train -> Loss: {train_loss:.4f}, Acc@1: {train_acc1:.2f}%" if not args.mixup else f" Train -> Loss: {train_loss:.4f}, Acc@1: N/A (Mixup)") + print(f" Valid -> Loss: {val_loss:.4f}, Acc@1: {val_acc1:.2f}%, Acc@5: {val_acc5:.2f}%") - # --- Generate Final Plots after Training --- - print(f"\n--- Generating plots for final Tiny ImageNet training ---") - # Plot Loss/Accuracy Curves from the main log file - plot_metrics(log_file_path_main, output_dir, "tinyimagenet_main") - # Load the best model checkpoint for final validation and CM - best_model_path = os.path.join(output_dir, 'model_best.pth') + print(f"\n--- Finished Final Training Stage (Tiny ImageNet) ---") + print(f"Best validation accuracy achieved: {best_acc:.2f}%") + + # --- Final Plotting and Confusion Matrix for Main Training --- + # Plot loss/accuracy curves for the main training + main_log_file = os.path.join(log_dir, "main_log.csv") + plot_metrics(main_log_file, experiment_output_dir, "main") + + # Load the *best* model for the final confusion matrix + best_model_path = os.path.join(experiment_output_dir, 'model_best.pth') if os.path.exists(best_model_path): - print(f"Loading best model from {best_model_path} for final validation...") - # Re-create model instance and load state dict (necessary if DataParallel was used during training) - final_model = get_model(model_name='swin_t', efficient=args.efficient, num_classes=final_num_classes) - load_checkpoint(final_model, filename=best_model_path) # Load only model weights + print(f"Loading best model from {best_model_path} for final confusion matrix...") + # Create a fresh instance or reload into the current one + # Re-create model to ensure clean state if needed, though loading state_dict should be fine + final_model = get_model(model_name='swin_t', efficient=args.efficient) # Head already replaced + # final_model = replace_head(final_model, num_classes=200) + + load_checkpoint(final_model, filename=best_model_path, load_optimizer_scheduler=False) final_model.to(args.device) - if torch.cuda.device_count() > 1: # Apply DataParallel if needed for validation - final_model = nn.DataParallel(final_model) - else: - print("Best model checkpoint not found. Using model from last epoch for final validation.") - final_model = model # Use the model from the last training epoch + final_model.eval() - # Generate Confusion Matrix using the best (or last) model - print("Running final validation on Tiny ImageNet for Confusion Matrix...") - # Use the correct criterion for validation - final_val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing).to(args.device) # Use standard CE for final validation CM - - _, _, _, final_preds, final_targets = validate( - final_model, val_loader, final_val_criterion, args.device, dataset_name="Final Best Model Val" + print("Running validation on best model for final confusion matrix...") + _, _, _, all_preds, all_targets = validate( + final_model, val_loader, criterion, args.device, return_preds_targets=True ) - class_names = [str(i) for i in range(final_num_classes)] # Generic class names - plot_confusion_matrix(final_preds, final_targets, class_names, output_dir, "tinyimagenet_main_best") - print("--- Final Training Stage Complete ---") + if all_preds is not None and all_targets is not None: + plot_confusion_matrix(all_preds.cpu().numpy(), all_targets.cpu().numpy(), 200, experiment_output_dir, "main_best_model") + else: + print("Could not generate final confusion matrix due to missing prediction data.") + + else: + print(f"Best model checkpoint '{best_model_path}' not found. Cannot generate confusion matrix for the best model.") + + print("\n>>> All Training Stages Complete <<<") def parse_args(): - parser = argparse.ArgumentParser(description='Swin Transformer Training with Pretraining and Plotting') + parser = argparse.ArgumentParser(description='Swin Transformer Training with Pretraining Options') - # Model parameters - parser.add_argument('--efficient', action='store_true', help='Use efficient model variant') + # --- Model --- + parser.add_argument('--model-name', type=str, default='swin_t', help='Name of the model architecture (e.g., swin_t)') + parser.add_argument('--efficient', action='store_true', help='Use efficient model variant (if available)') + # parser.add_argument('--no-pretrained', action='store_true', help='Do not use ImageNet pretrained weights initially') - # pre-training... + # --- Pre-training --- parser.add_argument('--pretrain-cifar', action='store_true', help='Pretrain on CIFAR-100 first') - parser.add_argument('--pretrain-caltech', action='store_true', help='Pretrain on Caltech-256 (after CIFAR if specified)') - parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages') - - # Training parameters - parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], metavar='N', help=f'Input batch size (default: {TRAIN_CONFIG["batch_size"]})') - parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], metavar='N', help=f'Number of epochs to train (default: {TRAIN_CONFIG["epochs"]})') - parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], metavar='LR', help=f'Initial learning rate (default: {TRAIN_CONFIG["learning_rate"]})') - parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], metavar='MINLR', help=f'Minimum learning rate for cosine scheduler (default: {TRAIN_CONFIG["min_lr"]})') - parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], metavar='N', help=f'Number of warmup epochs (default: {TRAIN_CONFIG["warmup_epochs"]})') - parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], metavar='WD', help=f'Weight decay (default: {TRAIN_CONFIG["weight_decay"]})') - parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], metavar='LS', help=f'Label smoothing factor (default: {TRAIN_CONFIG["label_smoothing"]})') - parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use (cuda or cpu)') - parser.add_argument('--mixup', action='store_true', default=False, help='Use mixup and cutmix augmentation (requires timm mixup implementation in get_loaders)') - - # Data loading - parser.add_argument('--workers', type=int, default=4, metavar='N', help='Number of data loading workers (default: 4)') - - # Checkpointing - parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Path to save output (checkpoints, plots)') - parser.add_argument('--tag', default='', help='Experiment tag to append to output/log directories') - parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], metavar='N', help='Save checkpoint every N epochs (default: 10)') - parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume final training stage from checkpoint path') - - # logs - parser.add_argument('--log-dir', default='logs', type=str, help='Directory to save all training logs (.csv files)') - - # Misc - parser.add_argument('--seed', type=int, default=42, metavar='S', help='Random seed (default: 42)') - parser.add_argument('--evaluate', action='store_true', help='Perform evaluation only (requires --resume)') - # parser.add_argument('--no-plots', action='store_true', help='Disable generating plots') # Optional: Add if you want to disable plots + parser.add_argument('--pretrain-caltech', action='store_true', help='Pretrain on Caltech-256 (after CIFAR if specified, otherwise from ImageNet)') + parser.add_argument('--skip-pretrain', action='store_true', help='Skip all pretraining stages and train directly on Tiny ImageNet') + # parser.add_argument('--pretrain-epochs', type=int, default=30, help='Number of epochs for each pretraining stage (if different from main epochs)') # Optional: Separate epoch control + + # --- Main Training --- + parser.add_argument('--batch-size', type=int, default=TRAIN_CONFIG['batch_size'], help='Input batch size for training') + parser.add_argument('--epochs', type=int, default=TRAIN_CONFIG['epochs'], help='Number of epochs to train') + parser.add_argument('--lr', '--learning-rate', type=float, default=TRAIN_CONFIG['learning_rate'], help='Initial learning rate') + parser.add_argument('--min-lr', type=float, default=TRAIN_CONFIG['min_lr'], help='Minimum learning rate for scheduler') + parser.add_argument('--warmup-epochs', type=int, default=TRAIN_CONFIG['warmup_epochs'], help='Number of warmup epochs') + parser.add_argument('--weight-decay', type=float, default=TRAIN_CONFIG['weight_decay'], help='Optimizer weight decay') + parser.add_argument('--label-smoothing', type=float, default=TRAIN_CONFIG['label_smoothing'], help='Label smoothing factor (if not using mixup)') + parser.add_argument('--mixup', action='store_true', help='Use mixup and cutmix augmentation (disables label smoothing)') + + # --- Data & Device --- + parser.add_argument('--img-size', type=int, default=MODEL_CONFIG['img_size'], help='Input image size') # Make img_size configurable if needed + parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers') + parser.add_argument('--device', default=TRAIN_CONFIG['device'], help='Device to use (e.g., "cuda", "cpu")') + + # --- Checkpointing & Logging --- + parser.add_argument('--output-dir', default=TRAIN_CONFIG['output_dir'], help='Base directory to save checkpoints and logs') + parser.add_argument('--log-dir', default=TRAIN_CONFIG['log_dir'], help='Directory within output-dir to save CSV logs and plots') + parser.add_argument('--tag', default='', type=str, help='Optional tag for experiment directory name') + parser.add_argument('--save-interval', type=int, default=TRAIN_CONFIG['save_interval'], help='Save checkpoint every N epochs during main training') + parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Path to latest checkpoint to resume main training (or for evaluation)') + + # --- Misc --- + parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility') + parser.add_argument('--evaluate', action='store_true', help='Perform evaluation only on the validation set (requires --resume or finds model_best.pth)') args = parser.parse_args() - # Set device explicitly + # Set device based on argument args.device = torch.device(args.device if torch.cuda.is_available() else "cpu") + # If mixup is used, disable label smoothing effect by setting it to 0 + if args.mixup: + args.label_smoothing = 0.0 + print("Mixup enabled, label smoothing set to 0.") + + # Ensure log_dir is inside output_dir unless absolute path is given + if not os.path.isabs(args.log_dir): + args.log_dir = os.path.join(args.output_dir, args.tag if args.tag else '', args.log_dir) + + return args + if __name__ == '__main__': - # Need 'random' for seeding - import random args = parse_args() + # Ensure necessary directories exist based on final paths + os.makedirs(args.log_dir, exist_ok=True) main(args)