import tensorflow as tf
import gzip
import json
import random
from collections import defaultdict

# Stream one JSON record per line from a gzipped file
def parse(path):
    g = gzip.open(path, 'rt')
    for l in g:
        yield json.loads(l)
userIDs = {}
itemIDs = {}
interactions = []
# Could adapt to any dataset, this one is from
# https://sites.google.com/eng.ucsd.edu/ucsdbookgraph/home
for d in parse("goodreads_reviews_comics_graphic.json.gz"):
    u = d['user_id']
    i = d['book_id']
    r = d['rating']
    if u not in userIDs: userIDs[u] = len(userIDs)
    if i not in itemIDs: itemIDs[i] = len(itemIDs)
    interactions.append((u, i, r))
print(len(interactions))
# mean rating, just for initialization
mu = sum([r for _,_,r in interactions]) / len(interactions)
# Gradient descent optimizer, experiment with learning rate
optimizer = tf.keras.optimizers.Adam(0.001)
class LatentFactorModel(tf.keras.Model):
    def __init__(self, mu, K, lamb):
        super(LatentFactorModel, self).__init__()
        # Initialize to average
        self.alpha = tf.Variable(mu)
        # Initialize to small random values
        self.betaU = tf.Variable(tf.random.normal([len(userIDs)], stddev=0.001))
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)], stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([len(userIDs), K], stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([len(itemIDs), K], stddev=0.001))
        self.lamb = lamb

    # Prediction for a single instance (useful for evaluation)
    def predict(self, u, i):
        p = self.alpha + self.betaU[u] + self.betaI[i] +\
            tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return p
    # Regularizer (the coefficient applies to all bias and factor terms)
    def reg(self):
        return self.lamb * (tf.reduce_sum(self.betaU**2) +
                            tf.reduce_sum(self.betaI**2) +
                            tf.reduce_sum(self.gammaU**2) +
                            tf.reduce_sum(self.gammaI**2))
    # Prediction for a sample of instances
    def predictSample(self, sampleU, sampleI):
        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        beta_u = tf.nn.embedding_lookup(self.betaU, u)
        beta_i = tf.nn.embedding_lookup(self.betaI, i)
        gamma_u = tf.nn.embedding_lookup(self.gammaU, u)
        gamma_i = tf.nn.embedding_lookup(self.gammaI, i)
        pred = self.alpha + beta_u + beta_i +\
               tf.reduce_sum(tf.multiply(gamma_u, gamma_i), 1)
        return pred

    # Loss
    def call(self, sampleU, sampleI, sampleR):
        pred = self.predictSample(sampleU, sampleI)
        r = tf.convert_to_tensor(sampleR, dtype=tf.float32)
        return tf.nn.l2_loss(pred - r) / len(sampleR)
# Experiment with number of factors and regularization rate
model = LatentFactorModel(mu, 5, 0.00001)
def trainingStep(interactions):
    Nsamples = 50000
    with tf.GradientTape() as tape:
        sampleU, sampleI, sampleR = [], [], []
        for _ in range(Nsamples):
            u, i, r = random.choice(interactions)
            sampleU.append(userIDs[u])
            sampleI.append(itemIDs[i])
            sampleR.append(r)
        loss = model(sampleU, sampleI, sampleR)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients((grad, var) for
                              (grad, var) in zip(gradients, model.trainable_variables)
                              if grad is not None)
    return loss.numpy()
# 10 iterations of gradient descent
for i in range(10):
    obj = trainingStep(interactions)
    print("iteration " + str(i) + ", objective = " + str(obj))
# Prediction (e.g.):
print(interactions[0])
print(model.predict(userIDs['dc3763cdb9b2cae805882878eebb6a32'], itemIDs['18471619']).numpy())
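
# A minimal evaluation sketch (sampleMSE is a hypothetical helper, not part of
# the original listing): estimate the model's mean squared error on a random
# sample of interactions. This samples from the training interactions, since
# the listing doesn't build a held-out set; a proper experiment would.
def sampleMSE(model, interactions, nSamples=10000):
    sample = random.sample(interactions, min(nSamples, len(interactions)))
    sampleU = [userIDs[u] for u, _, _ in sample]
    sampleI = [itemIDs[i] for _, i, _ in sample]
    ratings = [r for _, _, r in sample]
    preds = model.predictSample(sampleU, sampleI).numpy()
    return sum((p - r)**2 for p, r in zip(preds, ratings)) / len(sample)

# e.g.:
# sampleMSE(model, interactions)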
# Items each user has interacted with (used for negative sampling below)
itemsPerUser = defaultdict(set)
for u, i, _ in interactions:
    itemsPerUser[u].add(i)

items = list(itemIDs.keys())
# Experiment with learning rate
optimizer = tf.keras.optimizers.Adam(0.001)
class BPRbatch(tf.keras.Model):
    def __init__(self, K, lamb):
        super(BPRbatch, self).__init__()
        # Initialize variables
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)], stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([len(userIDs), K], stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([len(itemIDs), K], stddev=0.001))
        # Regularization coefficient
        self.lamb = lamb

    # Prediction for a single instance
    def predict(self, u, i):
        p = self.betaI[i] + tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return p
    # Regularizer (the coefficient applies to all terms)
    def reg(self):
        return self.lamb * (tf.nn.l2_loss(self.betaI) +
                            tf.nn.l2_loss(self.gammaU) +
                            tf.nn.l2_loss(self.gammaI))
    # Compatibility scores for a batch of (user, item) pairs
    def score(self, sampleU, sampleI):
        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        beta_i = tf.nn.embedding_lookup(self.betaI, i)
        gamma_u = tf.nn.embedding_lookup(self.gammaU, u)
        gamma_i = tf.nn.embedding_lookup(self.gammaI, i)
        x_ui = beta_i + tf.reduce_sum(tf.multiply(gamma_u, gamma_i), 1)
        return x_ui

    # Loss: negative log-likelihood of the BPR objective over the batch
    def call(self, sampleU, sampleI, sampleJ):
        x_ui = self.score(sampleU, sampleI)
        x_uj = self.score(sampleU, sampleJ)
        return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))
model = BPRbatch(5, 0.00001)
def trainingStep(interactions):
    Nsamples = 50000
    with tf.GradientTape() as tape:
        sampleU, sampleI, sampleJ = [], [], []
        for _ in range(Nsamples):
            u, i, _ = random.choice(interactions)  # positive sample
            j = random.choice(items)               # negative sample
            while j in itemsPerUser[u]:
                j = random.choice(items)
            sampleU.append(userIDs[u])
            sampleI.append(itemIDs[i])
            sampleJ.append(itemIDs[j])
        loss = model(sampleU, sampleI, sampleJ)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients((grad, var) for
                              (grad, var) in zip(gradients, model.trainable_variables)
                              if grad is not None)
    return loss.numpy()
# Run training iterations (probably want more than 10...)
for i in range(10):
    obj = trainingStep(interactions)
    print("iteration " + str(i) + ", objective = " + str(obj))
# In this case the output is just a compatibility score, rather than a predicted rating
print(model.predict(userIDs['dc3763cdb9b2cae805882878eebb6a32'], itemIDs['18471619']).numpy())
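
# A small recommendation sketch (recommend is a hypothetical helper, not part
# of the original listing): rank every item the user hasn't interacted with by
# its BPR score, using the batch score() method, and return the top N item IDs.
def recommend(model, u, N=10):
    candidates = [i for i in items if i not in itemsPerUser[u]]
    scores = model.score([userIDs[u]] * len(candidates),
                         [itemIDs[i] for i in candidates]).numpy()
    ranked = sorted(zip(scores, candidates), reverse=True)
    return [i for _, i in ranked[:N]]

# e.g.:
# recommend(model, 'dc3763cdb9b2cae805882878eebb6a32')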