-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlearner_kernel.py
More file actions
59 lines (55 loc) · 2.83 KB
/
learner_kernel.py
File metadata and controls
59 lines (55 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import scipy as sp
from sklearn.metrics.pairwise import rbf_kernel
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
class KernelTS:
    """Kernelized contextual bandit learner (Thompson Sampling or UCB).

    Stores a history of (context, reward) pairs and the Gram matrix K_t of
    the kernel k(x, y) = exp(-||x - y||) over stored contexts. Arm scores in
    select() come from the posterior mean/variance derived from the
    regularized inverse U_t = (K_t + lamdba * I)^-1.
    """

    # Cap on stored history points; once reached, train() becomes a no-op.
    MAX_HISTORY = 1000

    def __init__(self, dim, lamdba=1, nu=1, style='ts', device='cuda'):
        """
        Args:
            dim: context feature dimension (stored for reference only).
            lamdba: ridge regularization weight (original spelling kept).
            nu: exploration scale multiplier.
            style: 'ts' (Thompson Sampling) or 'ucb' (upper confidence bound).
            device: torch device string; default 'cuda' matches the original
                hard-coded behavior, pass 'cpu' to run without a GPU.
        """
        self.dim = dim
        self.lamdba = lamdba
        self.nu = nu
        self.x_t = None          # (history_len, f) float tensor of past contexts
        self.r_t = None          # (history_len,) float tensor of past rewards
        self.history_len = 0
        self.scale = self.lamdba * self.nu  # prior variance scale
        self.style = style
        self.device = torch.device(device)
        self.U_t = None          # (K_t + lamdba*I)^-1, refreshed by train()
        self.K_t = None          # Gram matrix of stored contexts

    def select(self, context):
        """Score every candidate arm and pick the best.

        Args:
            context: (a, f) numpy array, one row per candidate arm.

        Returns:
            Tuple of (chosen arm index, 1, mean posterior variance, max score).

        Raises:
            ValueError: if self.style is neither 'ts' nor 'ucb'.
        """
        a, f = context.shape
        if self.history_len == 0:
            # No observations yet: prior mean 0, prior variance `scale`.
            mu_t = torch.zeros((a,), device=self.device)
            sigma_t = self.scale * torch.ones((a,), device=self.device)
        else:
            c_t = torch.from_numpy(context).float().to(self.device)
            # Pairwise differences between candidates and stored contexts.
            delta_t = c_t.reshape((a, 1, -1)) - self.x_t.reshape((1, self.history_len, -1))
            k_t = torch.exp(-delta_t.norm(dim=2))  # (a, history_len) kernel values
            mu_t = k_t.matmul(self.U_t.matmul(self.r_t))
            sigma_t = self.scale * (
                torch.ones((a,), device=self.device)
                - torch.diag(k_t.matmul(self.U_t.matmul(k_t.T)))
            )
            # Floating-point error can push the posterior variance slightly
            # negative, which breaks MultivariateNormal / sqrt; clamp it.
            sigma_t = sigma_t.clamp(min=1e-12)
        if self.style == 'ts':
            r = MultivariateNormal(mu_t, torch.diag(sigma_t)).sample()
        elif self.style == 'ucb':
            r = mu_t + torch.sqrt(sigma_t)
        else:
            # The original fell through with `r` unbound (UnboundLocalError);
            # fail with a clear message instead.
            raise ValueError(f"unknown style: {self.style!r}")
        return torch.argmax(r), 1, torch.mean(sigma_t), torch.max(r)

    def train(self, context, reward):
        """Append one (context, reward) observation and refresh U_t.

        Args:
            context: 1-D numpy array, the context of the played arm.
            reward: scalar reward observed for that arm.

        Returns:
            0 (kept for interface compatibility).
        """
        if self.history_len >= self.MAX_HISTORY:
            # History is capped: stop absorbing data. The original code still
            # ran `history_len += 1` and rebuilt U_t here, which desynchronized
            # history_len from K_t's size and crashed torch.eye/torch.inverse.
            return 0
        if self.x_t is None:
            # First observation: 1x1 Gram matrix, k(x, x) = exp(0) = 1.
            self.x_t = torch.from_numpy(context).float().to(self.device).reshape((1, -1))
            self.r_t = torch.tensor(reward, device=self.device, dtype=torch.float).reshape((-1,))
            self.K_t = torch.tensor(1, device=self.device, dtype=torch.float).reshape((1, 1))
        else:
            c_t = torch.from_numpy(context).float().to(self.device).reshape((1, -1))
            r_t = torch.tensor(reward, device=self.device, dtype=torch.float).reshape((-1,))
            delta_t = c_t.reshape((1, 1, -1)) - self.x_t.reshape((1, self.history_len, -1))
            self.x_t = torch.cat((self.x_t, c_t), dim=0)
            self.r_t = torch.cat((self.r_t, r_t), dim=0)
            # Grow K_t by one row and one column of kernel values against the
            # new context (diagonal entry is k(x, x) = 1).
            k_t = torch.exp(-delta_t.norm(dim=2)).reshape((-1, 1))
            bottom = torch.cat((k_t.T, torch.ones((1, 1), dtype=torch.float, device=self.device)), dim=1)
            right = torch.cat((self.K_t, k_t), dim=1)
            self.K_t = torch.cat((right, bottom), dim=0)
        self.history_len += 1
        # Regularized kernel inverse consumed by select().
        self.U_t = torch.inverse(self.K_t + self.lamdba * torch.eye(self.history_len, device=self.device))
        return 0