-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathgradient_reversal_example.py
More file actions
48 lines (39 loc) · 1.54 KB
/
gradient_reversal_example.py
File metadata and controls
48 lines (39 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch.nn as nn
from torch.autograd import Function
'''
Very easy template to start for developing your AlexNet with DANN
Has not been tested, might contain incompatibilities with most recent versions of PyTorch (you should address this)
However, the logic is consistent
'''
class ReverseLayerF(Function):
    """Gradient Reversal Layer (GRL) used in DANN.

    Acts as the identity in the forward pass; during backpropagation it
    flips the sign of the incoming gradient and scales it by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor so backward() can use it.
        ctx.alpha = alpha
        # view_as is a no-op identity that keeps autograd happy.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; None for the alpha argument,
        # which needs no gradient of its own.
        return -ctx.alpha * grad_output, None
class RandomNetworkWithReverseGrad(nn.Module):
    """Template network for Domain-Adversarial training (DANN).

    Consists of a shared feature extractor, a label classifier head, and a
    domain-discriminator head fed through a gradient reversal layer.
    The ``nn.Sequential(...)`` bodies are placeholders to be filled in.
    """

    def __init__(self, **kwargs):
        super(RandomNetworkWithReverseGrad, self).__init__()
        self.features = nn.Sequential(...)
        self.classifier = nn.Sequential(...)
        self.dann_classifier = nn.Sequential(...)

    def forward(self, x, alpha=None):
        # BUG FIX: the original assigned the module itself
        # (features = self.features) instead of applying it to the input.
        features = self.features(x)
        # Flatten the features to (batch, -1):
        features = features.view(features.size(0), -1)
        # If we pass alpha, we can assume we are training the discriminator
        if alpha is not None:
            # gradient reversal layer (backward gradients will be reversed)
            reverse_feature = ReverseLayerF.apply(features, alpha)
            discriminator_output = ...
            return discriminator_output
        # If we don't pass alpha, we assume we are training with supervision
        else:
            # do something else
            class_outputs = ...
            return class_outputs