In [1]:
import tensorflow as tf
import array
import gzip
import random
from tensorflow.keras import Model
from collections import defaultdict
In [2]:
def parse(path):
    g = gzip.open(path, 'r')
    for l in g:
        yield eval(l)
In [3]:
userIDs = {}
itemIDs = {}
interactions = []

# Could adapt to any dataset, this one is from
# https://sites.google.com/eng.ucsd.edu/ucsdbookgraph/home
for d in parse("goodreads_reviews_comics_graphic.json.gz"):
    u = d['user_id']
    i = d['book_id']
    r = d['rating']
    if not u in userIDs: userIDs[u] = len(userIDs)
    if not i in itemIDs: itemIDs[i] = len(itemIDs)
    interactions.append((u,i,r))
In [4]:
len(interactions)
Out[4]:
542338

Latent Factor Model

In [5]:
# mean rating, just for initialization
mu = sum([r for _,_,r in interactions]) / len(interactions)
In [6]:
# Gradient descent optimizer, experiment with learning rate
optimizer = tf.keras.optimizers.Adam(0.001)
In [7]:
class LatentFactorModel(tf.keras.Model):
    def __init__(self, mu, K, lamb):
        super(LatentFactorModel, self).__init__()
        # Initialize to average
        self.alpha = tf.Variable(mu)
        # Initialize to small random values
        self.betaU = tf.Variable(tf.random.normal([len(userIDs)],stddev=0.001))
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)],stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([len(userIDs),K],stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([len(itemIDs),K],stddev=0.001))
        self.lamb = lamb

    # Prediction for a single instance (useful for evaluation)
    def predict(self, u, i):
        p = self.alpha + self.betaU[u] + self.betaI[i] +\
            tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return p

    # Regularizer
    def reg(self):
        return self.lamb * tf.reduce_sum(self.betaU**2) +\
                           tf.reduce_sum(self.betaI**2) +\
                           tf.reduce_sum(self.gammaU**2) +\
                           tf.reduce_sum(self.gammaI**2)
    
    # Prediction for a sample of instances
    def predictSample(self, sampleU, sampleI):
        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        beta_u = tf.nn.embedding_lookup(self.betaU, u)
        beta_i = tf.nn.embedding_lookup(self.betaI, i)
        gamma_u = tf.nn.embedding_lookup(self.gammaU, u)
        gamma_i = tf.nn.embedding_lookup(self.gammaI, i)
        pred = self.alpha + beta_u + beta_i +\
               tf.reduce_sum(tf.multiply(gamma_u, gamma_i), 1)
        return pred
    
    # Loss
    def call(self, sampleU, sampleI, sampleR):
        pred = self.predictSample(sampleU, sampleI)
        r = tf.convert_to_tensor(sampleR, dtype=tf.float32)
        return tf.nn.l2_loss(pred - r) / len(sampleR)
In [8]:
# Experiment with number of factors and regularization rate
model = LatentFactorModel(mu, 5, 0.00001)
In [9]:
def trainingStep(interactions):
    Nsamples = 50000
    with tf.GradientTape() as tape:
        sampleU, sampleI, sampleR = [], [], []
        for _ in range(Nsamples):
            u,i,r = random.choice(interactions)
            sampleU.append(userIDs[u])
            sampleI.append(itemIDs[i])
            sampleR.append(r)

        loss = model(sampleU,sampleI,sampleR)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients((grad, var) for
                              (grad, var) in zip(gradients, model.trainable_variables)
                              if grad is not None)
    return loss.numpy()
In [10]:
# 10 iterations of gradient descent
for i in range(10):
    obj = trainingStep(interactions)
    print("iteration " + str(i) + ", objective = " + str(obj))
iteration 0, objective = 1.4984645
iteration 1, objective = 0.991089
iteration 2, objective = 0.84890044
iteration 3, objective = 0.91853225
iteration 4, objective = 0.9871618
iteration 5, objective = 0.91327095
iteration 6, objective = 0.81601316
iteration 7, objective = 0.77269465
iteration 8, objective = 0.7903811
iteration 9, objective = 0.7932894
In [11]:
# Prediction (e.g.):
In [12]:
interactions[0]
Out[12]:
('dc3763cdb9b2cae805882878eebb6a32', '18471619', 3)
In [13]:
model.predict(userIDs['dc3763cdb9b2cae805882878eebb6a32'], itemIDs['18471619']).numpy()
Out[13]:
3.7737765

Bayesian Personalized Ranking (One-Class model)

In [14]:
itemsPerUser = defaultdict(set)
for u,i,_ in interactions:
    itemsPerUser[u].add(i)
In [15]:
items = list(itemIDs.keys())
In [16]:
# Experiment with learning rate
optimizer = tf.keras.optimizers.Adam(0.001)
In [17]:
class BPRbatch(tf.keras.Model):
    def __init__(self, K, lamb):
        super(BPRbatch, self).__init__()
        # Initialize variables
        self.betaI = tf.Variable(tf.random.normal([len(itemIDs)],stddev=0.001))
        self.gammaU = tf.Variable(tf.random.normal([len(userIDs),K],stddev=0.001))
        self.gammaI = tf.Variable(tf.random.normal([len(itemIDs),K],stddev=0.001))
        # Regularization coefficient
        self.lamb = lamb

    # Prediction for a single instance
    def predict(self, u, i):
        p = self.betaI[i] + tf.tensordot(self.gammaU[u], self.gammaI[i], 1)
        return p

    # Regularizer
    def reg(self):
        return self.lamb * tf.nn.l2_loss(self.betaI) +\
                           tf.nn.l2_loss(self.gammaU) +\
                           tf.nn.l2_loss(self.gammaI)
    
    def score(self, sampleU, sampleI):
        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)
        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)
        beta_i = tf.nn.embedding_lookup(self.betaI, i)
        gamma_u = tf.nn.embedding_lookup(self.gammaU, u)
        gamma_i = tf.nn.embedding_lookup(self.gammaI, i)
        x_ui = beta_i + tf.reduce_sum(tf.multiply(gamma_u, gamma_i), 1)
        return x_ui

    def call(self, sampleU, sampleI, sampleJ):
        x_ui = self.score(sampleU, sampleI)
        x_uj = self.score(sampleU, sampleJ)
        return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))

model = BPRbatch(5, 0.00001)
In [18]:
def trainingStep(interactions):
    Nsamples = 50000
    with tf.GradientTape() as tape:
        sampleU, sampleI, sampleJ = [], [], []
        for _ in range(Nsamples):
            u,i,_ = random.choice(interactions) # positive sample
            j = random.choice(items) # negative sample
            while j in itemsPerUser[u]:
                j = random.choice(items)
            sampleU.append(userIDs[u])
            sampleI.append(itemIDs[i])
            sampleJ.append(itemIDs[j])

        loss = model(sampleU,sampleI,sampleJ)
        loss += model.reg()
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients((grad, var) for
                              (grad, var) in zip(gradients, model.trainable_variables)
                              if grad is not None)
    return loss.numpy()
In [19]:
# Run training iterations (probably want more than 10...)
for i in range(10):
    obj = trainingStep(interactions)
    print("iteration " + str(i) + ", objective = " + str(obj))
iteration 0, objective = 1.0645976
iteration 1, objective = 0.83555615
iteration 2, objective = 0.7738053
iteration 3, objective = 0.8066677
iteration 4, objective = 0.82900363
iteration 5, objective = 0.7990833
iteration 6, objective = 0.7587987
iteration 7, objective = 0.743567
iteration 8, objective = 0.7486358
iteration 9, objective = 0.74816036
In [20]:
# In this case just a score, rather than a prediction of a rating
model.predict(userIDs['dc3763cdb9b2cae805882878eebb6a32'], itemIDs['18471619']).numpy()
Out[20]:
0.008223855