Add script for experiments with ResNet50 #14

Open · wants to merge 2 commits into base: main
346 changes: 346 additions & 0 deletions examples/cdan_sdat_tllib.py
@@ -0,0 +1,346 @@
# Credits: https://github.com/thuml/Transfer-Learning-Library
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import os
import wandb

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

sys.path.append('../')
sys.path.append('../../../')
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
from tllib.utils.sam import SAM

sys.path.append('.')
import utils


def main(args: argparse.Namespace):
logger = CompleteLogger(args.log, args.phase)
print(args)

    if args.log_results:
        wandb.init(project="DA", entity="SDAT", name=args.log_name)
        wandb.config.update(args)

if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
cudnn.deterministic = True
warnings.warn('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, '
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')

cudnn.benchmark = True
device = args.device

# Data loading code
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
random_color_jitter=False, resize_size=args.resize_size,
norm_mean=args.norm_mean, norm_std=args.norm_std)
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
norm_mean=args.norm_mean, norm_std=args.norm_std)
print("train_transform: ", train_transform)
print("val_transform: ", val_transform)

train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
utils.get_dataset(args.data, args.root, args.source,
args.target, train_transform, val_transform)
train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.workers, drop_last=True)
train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.workers, drop_last=True)
val_loader = DataLoader(
val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
test_loader = DataLoader(
test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

train_source_iter = ForeverDataIterator(train_source_loader)
train_target_iter = ForeverDataIterator(train_target_loader)

# create model
print("=> using model '{}'".format(args.arch))
backbone = utils.get_model(args.arch, pretrain=not args.scratch)
print(backbone)
pool_layer = nn.Identity() if args.no_pool else None
classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
pool_layer=pool_layer, finetune=not args.scratch).to(device)
classifier_feature_dim = classifier.features_dim

if args.randomized:
domain_discri = DomainDiscriminator(
args.randomized_dim, hidden_size=1024).to(device)
else:
domain_discri = DomainDiscriminator(
classifier_feature_dim * num_classes, hidden_size=1024).to(device)
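    # Note: with the default multilinear conditioning, the discriminator input is the
    # flattened outer product of features and predictions, hence its dimension is
    # features_dim * num_classes; the randomized variant projects it to randomized_dim.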

# define optimizer and lr scheduler
base_optimizer = torch.optim.SGD
ad_optimizer = SGD(domain_discri.get_parameters(
), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr *
(1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
lr_scheduler_ad = LambdaLR(
ad_optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
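    # Note: SAM wraps the base SGD optimizer for the classifier and performs a
    # two-step update each iteration (see train() below), while the domain
    # discriminator is trained with plain SGD. Both schedules follow the same
    # polynomial decay in the step count, (1 + lr_gamma * step) ** (-lr_decay).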

# define loss function
domain_adv = ConditionalDomainAdversarialLoss(
domain_discri, entropy_conditioning=args.entropy,
num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
randomized_dim=args.randomized_dim
).to(device)
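    # Note: with --entropy, examples are reweighted by the entropy of their
    # predictions, prioritizing confident examples (the CDAN+E variant of the paper).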

# resume from the best checkpoint
if args.phase != 'train':
checkpoint = torch.load(
logger.get_checkpoint_path('best'), map_location='cpu')
classifier.load_state_dict(checkpoint)

    # analyze the model
if args.phase == 'analysis':
# extract features from both domains
feature_extractor = nn.Sequential(
classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
source_feature = collect_feature(
train_source_loader, feature_extractor, device)
target_feature = collect_feature(
train_target_loader, feature_extractor, device)
# plot t-SNE
tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
tsne.visualize(source_feature, target_feature, tSNE_filename)
print("Saving t-SNE to", tSNE_filename)
    # calculate A-distance, a measure of distribution discrepancy
A_distance = a_distance.calculate(
source_feature, target_feature, device)
print("A-distance =", A_distance)
return

if args.phase == 'test':
acc1 = utils.validate(test_loader, classifier, args, device)
print(acc1)
return

# start training
best_acc1 = 0.
for epoch in range(args.epochs):
print("lr_bbone:", lr_scheduler.get_last_lr()[0])
print("lr_btlnck:", lr_scheduler.get_last_lr()[1])
if args.log_results:
wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0],
"lr_btlnck": lr_scheduler.get_last_lr()[1]})

# train for one epoch
train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, ad_optimizer,
lr_scheduler, lr_scheduler_ad, epoch, args)
# evaluate on validation set
acc1 = utils.validate(val_loader, classifier, args, device)
if args.log_results:
wandb.log({'epoch': epoch, 'val_acc': acc1})

# remember best acc@1 and save checkpoint
torch.save(classifier.state_dict(),
logger.get_checkpoint_path('latest'))
if acc1 > best_acc1:
shutil.copy(logger.get_checkpoint_path('latest'),
logger.get_checkpoint_path('best'))
best_acc1 = max(acc1, best_acc1)

print("best_acc1 = {:3.1f}".format(best_acc1))

# evaluate on test set
classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
acc1 = utils.validate(test_loader, classifier, args, device)
print("test_acc1 = {:3.1f}".format(acc1))
if args.log_results:
wandb.log({'epoch': epoch, 'test_acc': acc1})

logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
domain_adv: ConditionalDomainAdversarialLoss, optimizer, ad_optimizer,
lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace):
batch_time = AverageMeter('Time', ':3.1f')
data_time = AverageMeter('Data', ':3.1f')
losses = AverageMeter('Loss', ':3.2f')
trans_losses = AverageMeter('Trans Loss', ':3.2f')
cls_accs = AverageMeter('Cls Acc', ':3.1f')
domain_accs = AverageMeter('Domain Acc', ':3.1f')
progress = ProgressMeter(
args.iters_per_epoch,
[batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
prefix="Epoch: [{}]".format(epoch))

device = args.device
# switch to train mode
model.train()
domain_adv.train()

end = time.time()
for i in range(args.iters_per_epoch):
x_s, labels_s = next(train_source_iter)[:2]
x_t, _ = next(train_target_iter)[:2]

x_s = x_s.to(device)
x_t = x_t.to(device)
labels_s = labels_s.to(device)

# measure data loading time
data_time.update(time.time() - end)
optimizer.zero_grad()
ad_optimizer.zero_grad()

# compute task loss for first step
x = torch.cat((x_s, x_t), dim=0)
y, f = model(x)
y_s, y_t = y.chunk(2, dim=0)
f_s, f_t = f.chunk(2, dim=0)
cls_loss = F.cross_entropy(y_s, labels_s)
loss = cls_loss
loss.backward()

# Calculate ϵ̂ (w) and add it to the weights
optimizer.first_step(zero_grad=True)
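        # (in tllib's SAM implementation, first_step ascends by rho along the
        # normalized gradient and caches the perturbation, so that second_step
        # can later restore the original weights)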

# Calculate task loss and domain loss
y, f = model(x)
y_s, y_t = y.chunk(2, dim=0)
f_s, f_t = f.chunk(2, dim=0)

cls_loss = F.cross_entropy(y_s, labels_s)
transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
domain_acc = domain_adv.domain_discriminator_accuracy
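        # overall objective at the perturbed point: source classification loss
        # plus the CDAN transfer loss weighted by --trade-off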
loss = cls_loss + transfer_loss * args.trade_off

cls_acc = accuracy(y_s, labels_s)[0]
if args.log_results:
wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss,
'transfer_loss': transfer_loss, 'domain_acc': domain_acc})

losses.update(loss.item(), x_s.size(0))
cls_accs.update(cls_acc, x_s.size(0))
domain_accs.update(domain_acc, x_s.size(0))
trans_losses.update(transfer_loss.item(), x_s.size(0))

loss.backward()
        # Update the domain discriminator parameters
        ad_optimizer.step()
        # SAM second step: restore the original weights, then take the base SGD
        # step using the gradients computed at the perturbed point
        optimizer.second_step(zero_grad=True)
lr_scheduler.step()
lr_scheduler_ad.step()

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
progress.display(i)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='CDAN with SDAT for Unsupervised Domain Adaptation')
# dataset parameters
parser.add_argument('root', metavar='DIR',
help='root path of dataset')
parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
' (default: Office31)')
parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
parser.add_argument('--train-resizing', type=str, default='default')
parser.add_argument('--val-resizing', type=str, default='default')
parser.add_argument('--resize-size', type=int, default=224,
help='the image size after resizing')
parser.add_argument('--no-hflip', action='store_true',
help='no random horizontal flipping during training')
parser.add_argument('--norm-mean', type=float, nargs='+',
default=(0.485, 0.456, 0.406), help='normalization mean')
parser.add_argument('--norm-std', type=float, nargs='+',
default=(0.229, 0.224, 0.225), help='normalization std')
# model parameters
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=utils.get_model_names(),
help='backbone architecture: ' +
' | '.join(utils.get_model_names()) +
' (default: resnet18)')
parser.add_argument('--bottleneck-dim', default=256, type=int,
help='Dimension of bottleneck')
parser.add_argument('--no-pool', action='store_true',
help='no pool layer after the feature extractor.')
parser.add_argument('--scratch', action='store_true',
help='whether train from scratch.')
parser.add_argument('-r', '--randomized', action='store_true',
help='using randomized multi-linear-map (default: False)')
parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
help='randomized dimension when using randomized multi-linear-map (default: 1024)')
parser.add_argument('--entropy', default=False,
action='store_true', help='use entropy conditioning')
parser.add_argument('--trade-off', default=1., type=float,
help='the trade-off hyper-parameter for transfer loss')
# training parameters
parser.add_argument('-b', '--batch-size', default=32, type=int,
metavar='N',
help='mini-batch size (default: 32)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--lr-gamma', default=0.001,
type=float, help='parameter for lr scheduler')
parser.add_argument('--lr-decay', default=0.75,
type=float, help='parameter for lr scheduler')
parser.add_argument('--momentum', default=0.9,
type=float, metavar='M', help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
metavar='W', help='weight decay (default: 1e-3)',
dest='weight_decay')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
help='Number of iterations per epoch')
parser.add_argument('-p', '--print-freq', default=100, type=int,
metavar='N', help='print frequency (default: 100)')
parser.add_argument('--seed', default=None, type=int,
help='seed for initializing training. ')
parser.add_argument('--per-class-eval', action='store_true',
help='whether output per-class accuracy during evaluation')
parser.add_argument("--log", type=str, default='cdan',
help="Where to save logs, checkpoints and debugging images.")
parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
help="When phase is 'test', only test the model."
"When phase is 'analysis', only analysis the model.")
parser.add_argument('--log_results', action='store_true',
help="To log results in wandb")
parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
parser.add_argument('--log_name', type=str,
default="log", help="log name for wandb")
    parser.add_argument('--rho', type=float, default=0.05,
                        help="rho parameter for the SAM optimizer (default: 0.05)")
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device
main(args)
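
A plausible invocation of this script for an Office-31 experiment (the dataset root, domain codes, and log path below are illustrative, not taken from the PR):

python cdan_sdat_tllib.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --rho 0.05 --log logs/cdan_sdat/Office31_A2W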