import gzip
import random
import scipy
import tensorflow as tf
from collections import defaultdict
from implicit import bpr
from surprise import SVD, Reader, Dataset
from surprise.model_selection import train_test_split
Data is available at http://cseweb.ucsd.edu/~jmcauley/pml/data/. Download it, save it to your own directory, and set `dataDir` below to that directory.
# Directory holding the downloaded data files. File names are appended to it by
# plain string concatenation, so the trailing slash is required.
dataDir = "/home/jmcauley/pml_data/"
Using the library's inbuilt data reader, extract tsv-formatted data
reader = Reader(line_format='user item rating', sep='\t')
data = Dataset.load_from_file(dataDir + "goodreads_fantasy.tsv", reader=reader)
Standard latent-factor model
model = SVD()
Inbuilt functions to split into training and test fractions
trainset, testset = train_test_split(data, test_size=.25)
Fit the model and extract predictions
model.fit(trainset)
predictions = model.test(testset)
Estimate for a single (test) rating
predictions[0].est
4.001383054002914
MSE for model predictions (test set)
# Accumulate the squared error between the true rating and the estimate,
# then report the mean over the test predictions
sse = 0
for p in predictions:
    residual = p.r_ui - p.est
    sse += residual ** 2
print(sse / len(predictions))
1.1877783834090485
def parseData(fname):
    """Yield one record (dict) per line of a gzipped file, minus its review text.

    Args:
        fname: path to a gzip-compressed file holding one dict literal per line.

    Yields:
        The dict parsed from each line, with the 'review_text' key removed
        (if present) to save memory when the reviews are not needed.
    """
    # The context manager closes the gzip handle once the generator is
    # exhausted or discarded
    with gzip.open(fname) as f:
        for line in f:
            # NOTE(security): eval() executes whatever the line contains; only
            # use this on trusted data files. If the file is strict JSON,
            # json.loads would be a safer parser.
            d = eval(line)
            # pop() with a default tolerates records that carry no review text
            d.pop('review_text', None)
            yield d
Full dataset of Goodreads fantasy reviews (fairly memory-hungry, could be replaced by something smaller)
data = list(parseData(dataDir + "goodreads_reviews_fantasy_paranormal.json.gz"))
random.shuffle(data)
Example from the dataset
data[0]
{'user_id': '29b4208f9c7fb6d173dfa7393d465d1d', 'book_id': '8087526', 'review_id': 'b2cb7fc6ccba961a3042b090f84a04cf', 'rating': 4, 'date_added': 'Sun Dec 28 10:56:38 -0800 2014', 'date_updated': 'Mon Feb 29 11:33:35 -0800 2016', 'read_at': 'Thu Jan 15 00:00:00 -0800 2015', 'started_at': 'Sun Dec 28 00:00:00 -0800 2014', 'n_votes': 2, 'n_comments': 2}
Build a few utility data structures. Since we'll be converting the data to a sparse interaction matrix, the main task here is to map each user/item to an integer ID in the range 0 to nUsers-1 (respectively nItems-1).
userIDs = {}
itemIDs = {}
for d in data:
    u = d['user_id']
    i = d['book_id']
    # setdefault only inserts unseen keys, so IDs are handed out in order of
    # first appearance: 0, 1, 2, ...
    userIDs.setdefault(u, len(userIDs))
    itemIDs.setdefault(i, len(itemIDs))
nUsers = len(userIDs)
nItems = len(itemIDs)
nUsers,nItems
(256088, 258212)
Convert dataset to sparse matrix. Only storing positive feedback instances (i.e., rated items).
# Fill a LIL matrix entry by entry: one cell per observed (user, item) pair
Xui = scipy.sparse.lil_matrix((nUsers, nItems))
for d in data:
    row = userIDs[d['user_id']]
    col = itemIDs[d['book_id']]
    Xui[row, col] = 1
# CSR version of the same matrix, which is what gets passed to the model
Xui_csr = scipy.sparse.csr_matrix(Xui)
Bayesian Personalized Ranking model with 5 latent factors
model = bpr.BayesianPersonalizedRanking(factors = 5)
Fit the model
model.fit(Xui_csr)
Get recommendations for a particular user (the first one), and get the items related to a particular item (i.e., those with similar latent factors)
recommended = model.recommend(0, Xui_csr[0])
related = model.similar_items(0)
related
(array([ 0, 50907, 4837, 31854, 13388, 57087, 66904, 21845, 28072, 23460], dtype=int32), array([1. , 0.9951341 , 0.9949262 , 0.99489623, 0.9937351 , 0.9932829 , 0.99264044, 0.9920761 , 0.9916638 , 0.99135584], dtype=float32))
Extract user and item factors
itemFactors = model.item_factors
userFactors = model.user_factors
itemFactors[0]
array([ 0.0347619 , -0.50331426, -1.475029 , 1.2996856 , -0.8502873 , -0.65859324], dtype=float32)
def parse(path):
    """Yield one parsed record per line of a gzipped file.

    Args:
        path: path to a gzip-compressed file holding one dict literal per line.

    Yields:
        The object obtained by evaluating each line.
    """
    # The context manager closes the gzip handle once the generator is
    # exhausted or discarded (the handle was previously never closed)
    with gzip.open(path, 'r') as g:
        for l in g:
            # NOTE(security): eval() executes whatever the line contains; only
            # use this on trusted data files. If the file is strict JSON,
            # json.loads would be a safer parser.
            yield eval(l)
Goodreads comic book data
userIDs = {}
itemIDs = {}
interactions = []
for d in parse(dataDir + "goodreads_reviews_comics_graphic.json.gz"):
    u, i, r = d['user_id'], d['book_id'], d['rating']
    # Integer IDs are assigned in order of first appearance
    userIDs.setdefault(u, len(userIDs))
    itemIDs.setdefault(i, len(itemIDs))
    # Interactions keep the raw string IDs; they are mapped to integers at use
    interactions.append((u, i, r))
random.shuffle(interactions)
len(interactions)
542338
Split into train and test sets
# First 90% of the (already shuffled) interactions for training, the rest for testing
nTrain = int(len(interactions) * 0.9)
nTest = len(interactions) - nTrain
interactionsTrain, interactionsTest = interactions[:nTrain], interactions[nTrain:]
# Adjacency lists built from the training split only
itemsPerUser, usersPerItem = defaultdict(list), defaultdict(list)
for u, i, r in interactionsTrain:
    itemsPerUser[u].append(i)
    usersPerItem[i].append(u)
Mean rating, just for initialization
mu = sum([r for _,_,r in interactionsTrain]) / len(interactionsTrain)
Gradient descent optimizer, could experiment with learning rate
optimizer = tf.keras.optimizers.Adam(0.1)
Latent factor model tensorflow class
class LatentFactorModel(tf.keras.Model):
    """Latent factor rating model: alpha + betaU[u] + betaI[i] + gammaU[u] . gammaI[i].

    Parameter sizes are taken from the module-level ``userIDs`` / ``itemIDs``
    dictionaries at construction time.
    """

    def __init__(self, mu, K, lamb):
        """mu: initial global offset; K: number of latent factors; lamb: regularization strength."""
        super().__init__()
        n_users = len(userIDs)
        n_items = len(itemIDs)
        # Global offset starts at the supplied mean
        self.alpha = tf.Variable(mu)
        # Biases and factors start as small Gaussian noise
        self.betaU = tf.Variable(tf.random.normal([n_users], stddev=0.001))
        self.betaI = tf.Variable(tf.random.normal([n_items], stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([n_users, K], stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([n_items, K], stddev=0.001))
        self.lamb = lamb

    def predict(self, u, i):
        """Prediction for a single (user index, item index) pair; useful for evaluation."""
        interaction = tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return self.alpha + self.betaU[u] + self.betaI[i] + interaction

    def reg(self):
        """Regularizer: lamb times the sum of squares of all biases and factors (alpha excluded)."""
        user_bias_sq = tf.reduce_sum(self.betaU ** 2)
        item_bias_sq = tf.reduce_sum(self.betaI ** 2)
        user_factor_sq = tf.reduce_sum(self.gammaU ** 2)
        item_factor_sq = tf.reduce_sum(self.gammaI ** 2)
        return self.lamb * (user_bias_sq + item_bias_sq + user_factor_sq + item_factor_sq)

    def predictSample(self, sampleU, sampleI):
        """Predictions for parallel sequences of user indices and item indices."""
        user_idx = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        item_idx = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        user_bias = tf.nn.embedding_lookup(self.betaU, user_idx)
        item_bias = tf.nn.embedding_lookup(self.betaI, item_idx)
        user_factors = tf.nn.embedding_lookup(self.gammaU, user_idx)
        item_factors = tf.nn.embedding_lookup(self.gammaI, item_idx)
        # Row-wise dot product of the user and item factor vectors
        interaction = tf.reduce_sum(tf.multiply(user_factors, item_factors), 1)
        return self.alpha + user_bias + item_bias + interaction

    def call(self, sampleU, sampleI, sampleR):
        """Loss for a batch: tf.nn.l2_loss of (prediction - rating), divided by the batch size."""
        target = tf.convert_to_tensor(sampleR, dtype=tf.float32)
        residual = self.predictSample(sampleU, sampleI) - target
        return tf.nn.l2_loss(residual) / len(sampleR)
Initialize the model. Could experiment with number of factors and regularization rate.
modelLFM = LatentFactorModel(mu, 5, 0.00001)
Training step (for the batch-based model from Chapter 5)
def trainingStep(model, interactions, Nsamples=50000):
    """Run one gradient step on a random batch of rating interactions.

    Relies on the module-level ``optimizer``, ``userIDs`` and ``itemIDs``.

    Args:
        model: model whose call(sampleU, sampleI, sampleR) returns the batch
            loss and which exposes reg() and trainable_variables.
        interactions: sequence of (user, item, rating) tuples keyed by raw IDs.
        Nsamples: batch size; interactions are drawn with replacement.

    Returns:
        The batch objective (loss plus regularizer) as a numpy value.
    """
    # Sampling is plain Python and needs no gradient tracking, so it is done
    # before the tape is opened
    sampleU, sampleI, sampleR = [], [], []
    for _ in range(Nsamples):
        u, i, r = random.choice(interactions)
        sampleU.append(userIDs[u])
        sampleI.append(itemIDs[i])
        sampleR.append(r)
    with tf.GradientTape() as tape:
        loss = model(sampleU, sampleI, sampleR)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    # Skip variables that received no gradient in this step
    optimizer.apply_gradients(
        (grad, var)
        for (grad, var) in zip(gradients, model.trainable_variables)
        if grad is not None
    )
    return loss.numpy()
Run 100 iterations (really 100 batches) of gradient descent
for i in range(100):
    obj = trainingStep(modelLFM, interactionsTrain)
    iteration = i + 1
    # Report the objective after every 10th batch
    if iteration % 10 == 0:
        print("iteration " + str(iteration) + ", objective = " + str(obj))
iteration 10, objective = 0.54607564 iteration 20, objective = 0.5176252 iteration 30, objective = 0.5195416 iteration 40, objective = 0.52042997 iteration 50, objective = 0.5179602 iteration 60, objective = 0.5080225 iteration 70, objective = 0.5091005 iteration 80, objective = 0.5124716 iteration 90, objective = 0.5148142 iteration 100, objective = 0.5149693
Prediction for a particular user/item pair
u,i,r = interactionsTest[0]
modelLFM.predict(userIDs[u], itemIDs[i]).numpy()
4.1900516
items = list(itemIDs.keys())
Batch-based version from Chapter 5
class BPRbatch(tf.keras.Model):
    """Bayesian Personalized Ranking model trained on (user, positive, negative) batches.

    The score of user u for item i is ``betaI[i] + gammaU[u] . gammaI[i]``.
    Parameter sizes are taken from the module-level ``userIDs`` / ``itemIDs``
    dictionaries at construction time.
    """

    def __init__(self, K, lamb):
        """K: number of latent factors; lamb: regularization coefficient."""
        super(BPRbatch, self).__init__()
        # Initialize variables to small random values
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)],stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([len(userIDs),K],stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([len(itemIDs),K],stddev=0.001))
        # Regularization coefficient
        self.lamb = lamb

    def predict(self, u, i):
        """Unnormalized score for a single (user index, item index) pair."""
        p = self.betaI[i] + tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return p

    def reg(self):
        """Regularizer: lamb times the summed tf.nn.l2_loss of all parameters."""
        return self.lamb * (tf.nn.l2_loss(self.betaI) +
                            tf.nn.l2_loss(self.gammaU) +
                            tf.nn.l2_loss(self.gammaI))

    def score(self, sampleU, sampleI):
        """Scores for parallel sequences of user indices and item indices."""
        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        beta_i = tf.nn.embedding_lookup(self.betaI, i)
        gamma_u = tf.nn.embedding_lookup(self.gammaU, u)
        gamma_i = tf.nn.embedding_lookup(self.gammaI, i)
        x_ui = beta_i + tf.reduce_sum(tf.multiply(gamma_u, gamma_i), 1)
        return x_ui

    def call(self, sampleU, sampleI, sampleJ):
        """BPR loss: mean negative log-sigmoid of (positive score - negative score)."""
        x_ui = self.score(sampleU, sampleI)
        x_uj = self.score(sampleU, sampleJ)
        # log_sigmoid is the numerically stable form of log(sigmoid(x)); the
        # naive composition yields log(0) = -inf for large negative margins
        return -tf.reduce_mean(tf.math.log_sigmoid(x_ui - x_uj))
optimizer = tf.keras.optimizers.Adam(0.1)
modelBPR = BPRbatch(5, 0.00001)
def trainingStepBPR(model, interactions, Nsamples=50000):
    """Run one gradient step of BPR on a random batch of (user, positive, negative) triples.

    Relies on the module-level ``optimizer``, ``userIDs``, ``itemIDs``,
    ``items`` and ``itemsPerUser``.

    Args:
        model: model whose call(sampleU, sampleI, sampleJ) returns the batch
            loss and which exposes reg() and trainable_variables.
        interactions: sequence of (user, item, rating) tuples keyed by raw IDs;
            the rating is ignored.
        Nsamples: batch size; interactions are drawn with replacement.

    Returns:
        The batch objective (loss plus regularizer) as a numpy value.
    """
    # Sampling is plain Python and needs no gradient tracking, so it is done
    # before the tape is opened
    sampleU, sampleI, sampleJ = [], [], []
    for _ in range(Nsamples):
        u, i, _ = random.choice(interactions)  # positive sample
        # .get() avoids inserting an empty entry into the defaultdict for
        # users that have no recorded items
        seen = itemsPerUser.get(u, ())
        j = random.choice(items)  # negative sample
        while j in seen:
            j = random.choice(items)
        sampleU.append(userIDs[u])
        sampleI.append(itemIDs[i])
        sampleJ.append(itemIDs[j])
    with tf.GradientTape() as tape:
        loss = model(sampleU, sampleI, sampleJ)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    # Skip variables that received no gradient in this step
    optimizer.apply_gradients(
        (grad, var)
        for (grad, var) in zip(gradients, model.trainable_variables)
        if grad is not None
    )
    return loss.numpy()
Run 100 batches of gradient descent
# Train on the training split only: the negative sampler consults itemsPerUser,
# which is built from interactionsTrain, and the example prediction that follows
# is made on an interaction from interactionsTest
for i in range(100):
    obj = trainingStepBPR(modelBPR, interactionsTrain)
    if (i % 10 == 9): print("iteration " + str(i+1) + ", objective = " + str(obj))
iteration 10, objective = 0.53013414 iteration 20, objective = 0.47684863 iteration 30, objective = 0.47061586 iteration 40, objective = 0.47537118 iteration 50, objective = 0.4760192 iteration 60, objective = 0.47407478 iteration 70, objective = 0.47333395 iteration 80, objective = 0.4722594 iteration 90, objective = 0.4723976 iteration 100, objective = 0.4724716
Prediction for a particular user/item pair. Note that this is an unnormalized score (which can be used for ranking)
u,i,_ = interactionsTest[0]
# In this case just a score (that can be used for ranking), rather than a prediction of a rating
modelBPR.predict(userIDs[u], itemIDs[i]).numpy()
3.1493201
Adapt the latent factor model above, simply deleting any terms associated with latent factors
class LatentFactorModelBiasOnly(tf.keras.Model):
    """Bias-only rating model: alpha + betaU[u] + betaI[i], with no latent factors.

    Parameter sizes are taken from the module-level ``userIDs`` / ``itemIDs``
    dictionaries at construction time.
    """

    def __init__(self, mu, lamb):
        """mu: initial global offset; lamb: regularization strength."""
        super().__init__()
        # Global offset starts at the supplied mean
        self.alpha = tf.Variable(mu)
        # Biases start as small Gaussian noise
        self.betaU = tf.Variable(tf.random.normal([len(userIDs)], stddev=0.001))
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)], stddev=0.001))
        self.lamb = lamb

    def predict(self, u, i):
        """Prediction for a single (user index, item index) pair; useful for evaluation."""
        return self.alpha + self.betaU[u] + self.betaI[i]

    def reg(self):
        """Regularizer: lamb times the sum of squares of both bias vectors (alpha excluded)."""
        user_bias_sq = tf.reduce_sum(self.betaU ** 2)
        item_bias_sq = tf.reduce_sum(self.betaI ** 2)
        return self.lamb * (user_bias_sq + item_bias_sq)

    def predictSample(self, sampleU, sampleI):
        """Predictions for parallel sequences of user indices and item indices."""
        user_idx = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        item_idx = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        user_bias = tf.nn.embedding_lookup(self.betaU, user_idx)
        item_bias = tf.nn.embedding_lookup(self.betaI, item_idx)
        return self.alpha + user_bias + item_bias

    def call(self, sampleU, sampleI, sampleR):
        """Loss for a batch: tf.nn.l2_loss of (prediction - rating), divided by the batch size."""
        target = tf.convert_to_tensor(sampleR, dtype=tf.float32)
        residual = self.predictSample(sampleU, sampleI) - target
        return tf.nn.l2_loss(residual) / len(sampleR)
# Use a fresh optimizer for the new model, as is done for every other model in
# this file: the previous instance has already been applied to modelBPR's
# variables and carries their Adam state (recent Keras versions may also reject
# variables the optimizer was not built with)
optimizer = tf.keras.optimizers.Adam(0.1)
modelBiasOnly = LatentFactorModelBiasOnly(mu, 0.00001)
def trainingStepBiasOnly(model, interactions):
    """Run one gradient step on a random batch of rating interactions.

    The procedure is exactly the one implemented by trainingStep (sample a
    batch, add the regularizer to the loss, apply gradients with the
    module-level ``optimizer``), so this delegates to it rather than keeping a
    second copy of the same code.

    Args:
        model: model whose call(sampleU, sampleI, sampleR) returns the batch
            loss and which exposes reg() and trainable_variables.
        interactions: sequence of (user, item, rating) tuples keyed by raw IDs.

    Returns:
        The batch objective (loss plus regularizer) as a numpy value.
    """
    return trainingStep(model, interactions)
for i in range(50):
    obj = trainingStepBiasOnly(modelBiasOnly, interactionsTrain)
    iteration = i + 1
    # Report the objective after every 10th batch
    if iteration % 10 == 0:
        print("iteration " + str(iteration) + ", objective = " + str(obj))
iteration 10, objective = 0.5770031 iteration 20, objective = 0.536491 iteration 30, objective = 0.5330345 iteration 40, objective = 0.52337563 iteration 50, objective = 0.5127952
Compute the MSEs for a model which always predicts the mean, versus one which involves bias terms
def MSE(predictions, labels):
    """Mean of the squared differences between paired predictions and labels.

    Pairs are formed with zip, so any surplus elements of the longer input
    are ignored.
    """
    squared_errors = []
    for predicted, actual in zip(predictions, labels):
        squared_errors.append((predicted - actual) ** 2)
    return sum(squared_errors) / len(squared_errors)
alwaysPredictMean = [mu for _ in interactionsTest]
labels = [r for _,_,r in interactionsTest]
MSE(alwaysPredictMean, labels)
1.3209473208629037
biasOnlyPredictions =\
[modelBiasOnly.predict(userIDs[u],itemIDs[i]).numpy() for u,i,_ in interactionsTest]
biasOnlyPredictions[0]
4.880166
MSE(biasOnlyPredictions, labels)
1.004844762789679
Performance of a complete latent factor model (using the latent factor model implementation in the examples above)
# Fresh optimizer and a 10-factor model, trained for 50 batches
optimizer = tf.keras.optimizers.Adam(0.1)
modelLFM = LatentFactorModel(mu, 10, 0.00001)
for i in range(50):
    obj = trainingStep(modelLFM, interactionsTrain)
    iteration = i + 1
    # Report the objective after every 10th batch
    if iteration % 10 == 0:
        print("iteration " + str(iteration) + ", objective = " + str(obj))
iteration 10, objective = 0.53806317 iteration 20, objective = 0.526853 iteration 30, objective = 0.5331862 iteration 40, objective = 0.5367909 iteration 50, objective = 0.5340424
predictions = [modelLFM.predict(userIDs[u],itemIDs[i]).numpy() for u,i,_ in interactionsTest]
MSE(predictions, labels)
1.0092721226433803
(probably needs a little more tuning in terms of number of latent factors, learning rate, etc.)
Experiment with rounding the predictions
predictionsRounded = [int(p + 0.5) for p in predictions]
MSE(predictionsRounded, labels)
1.0957701810672271
Seems to result in worse performance. For a rough explanation, consider a random variable that takes a value of "1" half the time and "2" half the time; in terms of the MSE, always predicting 1.5 (and always incurring moderate errors) is preferable to always predicting either of 1 or 2 (and incurring a large error half the time).
Following the BPR code from examples above
# Fresh optimizer and a 10-factor BPR model, trained for 50 batches
optimizer = tf.keras.optimizers.Adam(0.1)
modelBPR = BPRbatch(10, 0.00001)
for i in range(50):
    obj = trainingStepBPR(modelBPR, interactionsTrain)
    iteration = i + 1
    # Report the objective after every 10th batch
    if iteration % 10 == 0:
        print("iteration " + str(iteration) + ", objective = " + str(obj))
iteration 10, objective = 0.5285175 iteration 20, objective = 0.48673838 iteration 30, objective = 0.48348132 iteration 40, objective = 0.4870032 iteration 50, objective = 0.49010223
# Items each user interacted with in the test split
interactionsTestPerUser = defaultdict(set)
for u, i, _ in interactionsTest:
    interactionsTestPerUser[u].add(i)
# Every item that occurs anywhere in the test split
itemSet = {i for _, i, _ in interactionsTest}
AUC implementation
def AUCu(u, N): # N samples per user
    """Estimate the AUC for user u from up to N sampled (positive, negative) item pairs.

    Positives are drawn from the user's test interactions and negatives from
    the remaining items of ``itemSet``; a pair counts as a win when
    ``modelBPR`` scores the positive strictly higher. Relies on the
    module-level ``interactionsTestPerUser``, ``itemSet``, ``modelBPR``,
    ``userIDs`` and ``itemIDs``.

    Args:
        u: raw user ID.
        N: maximum number of pairs to sample.

    Returns:
        Fraction of sampled pairs that the model ranks correctly.

    Raises:
        ZeroDivisionError: if no pair can be formed (the user has no test
            positives, or no negative items are available).
    """
    positives = interactionsTestPerUser[u]
    negatives = itemSet.difference(positives)
    # Never ask for more samples than either pool can supply
    N = min(N, len(positives), len(negatives))
    # random.sample needs a sequence: passing a set raises TypeError on
    # Python 3.11+ (and was deprecated from 3.9)
    positive = random.sample(list(positives), N)
    negative = random.sample(list(negatives), N)
    win = 0
    for i, j in zip(positive, negative):
        si = modelBPR.predict(userIDs[u], itemIDs[i]).numpy()
        sj = modelBPR.predict(userIDs[u], itemIDs[j]).numpy()
        if si > sj:
            win += 1
    return win / N
def AUC():
    """Average of the per-user AUC (10 sampled pairs each) over all users with test interactions."""
    per_user = [AUCu(user, 10) for user in interactionsTestPerUser]
    return sum(per_user) / len(per_user)
AUC()
0.7953086008585195