The Yelp Dataset Challenge dataset provides information on a subset of its
It includes data from 10 cities worldwide, but the size of the data from each city differs widely,
e.g. there are 36,500 businesses in Arizona, but only 530 in Ontario.
To build a recommendation system, we first must clean the data and establish a clear goal for recommendations.
There are many types of business categories in the dataset. These categories are included as tags in the data, and a business can have multiple tags.
# Number of businesses in total
business_df['business_id'].count()
85901
To get an idea of the types of businesses included in the dataset, we can look at the "Categories" tag in the Businesses table.
categories, num_categories, top_categories = common_categories(business_df)
print("There are {} unique business categories in the dataset".format(num_categories))
print("The top categories by count are")
print('\n'.join('\t{}:{}'.format(*k) for i,k in enumerate(top_categories)))
There are 1017 unique business categories in the dataset
The top categories by count are
	Restaurants:26729
	Shopping:12444
	Food:10143
	Beauty & Spas:7490
	Health & Medical:6106
	Home Services:5866
	Nightlife:5507
	Automotive:4888
	Bars:4727
	Local Services:4041
	Active Life:3455
	Fashion:3395
	Event Planning & Services:3237
	Fast Food:3154
	Pizza:2881
	Mexican:2705
	Hotels & Travel:2673
	Sandwiches:2666
	American (Traditional):2608
	Arts & Entertainment:2447
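The `common_categories` helper is not shown in this section. As a rough sketch only, and assuming each business record carries a list of category tags, the counting could look like the following. The function name `count_categories` and the sample data are hypothetical and exist only for illustration.

```python
from collections import Counter


def count_categories(category_lists, top_n=20):
    """Count how often each category tag appears across businesses.

    category_lists: iterable of tag lists, one list per business
    (an assumed input shape; the real column layout may differ).
    """
    counts = Counter()
    for tags in category_lists:
        # A business may carry several tags; count each tag once per business.
        counts.update(set(tags or []))
    return counts, len(counts), counts.most_common(top_n)


# Toy data, not taken from the Yelp dataset.
sample = [
    ["Restaurants", "Pizza"],
    ["Restaurants", "Bars", "Nightlife"],
    ["Shopping"],
]
categories, num_categories, top_categories = count_categories(sample, top_n=2)
print(num_categories)     # 5 distinct tags in the toy data
print(top_categories[0])  # ('Restaurants', 2)
```

Because a business contributes to every tag it carries, the per-tag counts can sum to more than the number of businesses.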
Next, we explore the data to select the best city with which to work. Because many cities are actually larger metropolitan areas that include several cities, we divide the data by state and look at the counts by state. We would like to choose a location that has the most businesses, and restaurants in particular, but also the most reviews for those businesses (restaurants).
# Business counts
business_location_counts, top_restaurant_counts = data_location(business_df, business_df)
print("Number of all business and restaurant by state")
print('\n'.join('\t{}:{}'.format(*k) for i,k in enumerate(top_restaurant_counts[:12])))
# Reviews counts
review_location_counts, top_restaurant_review_counts = data_location(review_df, business_df)
print("Number of reviews for all business and restaurant by state")
print('\n'.join('\t{}:{}'.format(*k) for i,k in enumerate(top_restaurant_review_counts[:12])))
# Check-in counts
checkin_location_counts, top_restaurant_checkin_counts = data_location(checkin_df, business_df)
print("Number of check-ins for all business and restaurant by state")
print('\n'.join('\t{}:{}'.format(*k) for i,k in enumerate(top_restaurant_checkin_counts[:12])))
# Tip counts
tips_location_counts, top_restaurant_tips_counts = data_location(tip_df, business_df)
print("Number of tips for all business and restaurant by state")
print('\n'.join('\t{}:{}'.format(*k) for i,k in enumerate(top_restaurant_tips_counts[:12])))
Number of all business and restaurant by state
	AZ:{'all': 36500, 'restaurants': 9427}
	NV:{'all': 23591, 'restaurants': 5912}
	QC:{'all': 5591, 'restaurants': 3385}
	NC:{'all': 6835, 'restaurants': 2421}
	PA:{'all': 4086, 'restaurants': 1671}
	EDH:{'all': 3297, 'restaurants': 1232}
	WI:{'all': 3066, 'restaurants': 1172}
	BW:{'all': 1055, 'restaurants': 571}
	ON:{'all': 530, 'restaurants': 372}
	IL:{'all': 808, 'restaurants': 317}
	SC:{'all': 327, 'restaurants': 143}
	MLN:{'all': 161, 'restaurants': 75}
Number of reviews for all business and restaurant by state
	NV:{'all': 1154799, 'restaurants': 662428}
	AZ:{'all': 1029103, 'restaurants': 622446}
	NC:{'all': 165625, 'restaurants': 112794}
	PA:{'all': 111542, 'restaurants': 78754}
	QC:{'all': 88046, 'restaurants': 63987}
	WI:{'all': 69917, 'restaurants': 49243}
	EDH:{'all': 30003, 'restaurants': 15189}
	IL:{'all': 19045, 'restaurants': 13054}
	ON:{'all': 5854, 'restaurants': 4814}
	SC:{'all': 5004, 'restaurants': 3532}
	BW:{'all': 3721, 'restaurants': 2714}
	TX:{'all': 911, 'restaurants': 899}
Number of check-ins for all business and restaurant by state
	AZ:{'all': 24876, 'restaurants': 8570}
	NV:{'all': 17948, 'restaurants': 5505}
	QC:{'all': 4343, 'restaurants': 2741}
	NC:{'all': 5300, 'restaurants': 2212}
	PA:{'all': 3061, 'restaurants': 1457}
	WI:{'all': 2154, 'restaurants': 1018}
	EDH:{'all': 1899, 'restaurants': 822}
	ON:{'all': 398, 'restaurants': 297}
	IL:{'all': 491, 'restaurants': 259}
	BW:{'all': 214, 'restaurants': 160}
	SC:{'all': 229, 'restaurants': 125}
	MLN:{'all': 115, 'restaurants': 58}
Number of tips for all business and restaurant by state
	NV:{'all': 313993, 'restaurants': 183039}
	AZ:{'all': 245449, 'restaurants': 156933}
	NC:{'all': 35955, 'restaurants': 24222}
	PA:{'all': 19056, 'restaurants': 12392}
	QC:{'all': 13776, 'restaurants': 9526}
	WI:{'all': 10850, 'restaurants': 7660}
	EDH:{'all': 5159, 'restaurants': 2533}
	IL:{'all': 1983, 'restaurants': 1378}
	SC:{'all': 959, 'restaurants': 673}
	ON:{'all': 825, 'restaurants': 652}
	TX:{'all': 349, 'restaurants': 348}
	MLN:{'all': 302, 'restaurants': 176}
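The `data_location` helper is also not shown here. A minimal sketch of the idea, assuming each record (review, tip, check-in) references a business by id and each business has a state and a list of categories, might look like this. All names and data shapes below are hypothetical.

```python
from collections import defaultdict


def counts_by_state(records, businesses):
    """Tally records per state, for all businesses and for restaurants only.

    records: iterable of dicts with a 'business_id' key (reviews, tips, ...).
    businesses: dict mapping business_id -> {'state': str, 'categories': list}.
    Both shapes are assumptions made for this sketch.
    """
    counts = defaultdict(lambda: {"all": 0, "restaurants": 0})
    for rec in records:
        biz = businesses.get(rec["business_id"])
        if biz is None:
            continue  # record points at a business we do not know about
        counts[biz["state"]]["all"] += 1
        if "Restaurants" in biz["categories"]:
            counts[biz["state"]]["restaurants"] += 1
    # Rank states by restaurant count, largest first.
    ranked = sorted(counts.items(), key=lambda kv: kv[1]["restaurants"], reverse=True)
    return dict(counts), ranked


# Toy data for illustration only.
businesses = {
    "b1": {"state": "NV", "categories": ["Restaurants", "Pizza"]},
    "b2": {"state": "NV", "categories": ["Shopping"]},
    "b3": {"state": "AZ", "categories": ["Restaurants"]},
}
reviews = [{"business_id": b} for b in ["b1", "b1", "b2", "b3"]]
review_counts, ranked = counts_by_state(reviews, businesses)
print(ranked[0])  # ('NV', {'all': 3, 'restaurants': 2})
```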
To train and evaluate the recommender system, the dataset must be split by time. The earlier time period is used to train the recommender system and the later time period is used to evaluate it. To find an appropriate train-test split, the number of reviews over time is analyzed.
plot_reviews_over_time(counts, cummulative_counts, bin_edges, cutoff)
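One way to pick such a cut-off, sketched under the assumption that a fixed fraction of reviews should fall in the training period. The 80% fraction and the function names are illustrative choices, not the values used in this analysis.

```python
from datetime import date


def time_cutoff(review_dates, train_fraction=0.8):
    """Return the date by which `train_fraction` of the reviews were written."""
    ordered = sorted(review_dates)
    # Index of the last review that still belongs to the training period.
    last_train = max(int(len(ordered) * train_fraction) - 1, 0)
    return ordered[last_train]


def split_by_time(reviews, cutoff):
    """Split (date, payload) pairs into train (<= cutoff) and test (> cutoff)."""
    train = [r for r in reviews if r[0] <= cutoff]
    test = [r for r in reviews if r[0] > cutoff]
    return train, test


# Toy reviews: (date, rating). Not real data.
reviews = [(date(2014, m, 1), 3 + m % 3) for m in range(1, 11)]
cutoff = time_cutoff([d for d, _ in reviews], train_fraction=0.8)
train, test = split_by_time(reviews, cutoff)
print(cutoff, len(train), len(test))  # 2014-08-01 8 2
```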
If $r_{ij}$ is the rating user $i$ gave item $j$, the entries of the utility matrix $R$ are
$$ R_{ij} = \begin{cases} r_{ij} & \text{if user $i$ reviewed item $j$} \\ 0 & \text{otherwise} \end{cases} $$
# plot review counts per user
plot_review_counts(M)
average number of ratings per user: 2.982617507210227
# plot utility matrix
plot_sparse_matrix(M)
The recommendation system is evaluated by comparing predicted and actual star ratings for users who have reviews both before and after the training/testing cut-off.
Users with ratings only after the cut-off are disregarded and not included in the utility matrix, since training on "future" data is not realistic in a "real-world" setting.
The utility matrix is constructed with the test indices masked off and the users with ratings only after the cut-off removed.
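The masking described above can be sketched as follows. This is an illustration under simplifying assumptions: it uses a small dense NumPy array rather than a sparse matrix, and the tuple layout and function name are hypothetical.

```python
import numpy as np


def build_utility_matrix(ratings, cutoff):
    """Build a dense utility matrix from (user, item, stars, time) tuples.

    Ratings after `cutoff` are masked (left at 0) and returned as test
    indices; users with no rating up to the cut-off are dropped entirely.
    """
    train_users = sorted({u for u, _, _, t in ratings if t <= cutoff})
    items = sorted({b for _, b, _, _ in ratings})
    row = {u: i for i, u in enumerate(train_users)}
    col = {b: j for j, b in enumerate(items)}

    R = np.zeros((len(row), len(col)))
    test_indxs = []
    for u, b, stars, t in ratings:
        if u not in row:
            continue  # user only appears after the cut-off
        if t <= cutoff:
            R[row[u], col[b]] = stars
        else:
            test_indxs.append((row[u], col[b]))  # held out, stays 0 in R
    return R, test_indxs


# Toy ratings: (user, business, stars, time step). Illustrative only.
ratings = [
    ("u1", "b1", 5, 1), ("u1", "b2", 3, 4),
    ("u2", "b2", 4, 2),
    ("u3", "b1", 2, 5),  # u3 has no rating before the cut-off
]
R, test_indxs = build_utility_matrix(ratings, cutoff=3)
print(R)           # 2 users x 2 items, with (u1, b2) masked
print(test_indxs)  # [(0, 1)]
```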
In content-based recommendation systems, profiles are created for items and for users.
Creating good profiles relies on designing (or learning) good features.
Predictions are based on similarities between items or similarities between users.
In collaborative filtering, instead of constructing item and user profiles, users are represented by their corresponding row in the utility matrix and items by their corresponding column.
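To make the row/column representation concrete, here is a small sketch that compares users (rows) and items (columns) with cosine similarity. Cosine similarity is one common choice and is an assumption here, since the text does not name a specific measure. The toy matrix is made up.

```python
import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity between two rating vectors (0 means 'not rated')."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom > 0 else 0.0


# Toy utility matrix: rows are users, columns are items.
R = np.array([
    [5.0, 4.0, 0.0],
    [4.0, 5.0, 0.0],
    [0.0, 0.0, 3.0],
])
sim_users_01 = cosine_similarity(R[0], R[1])        # users with similar tastes
sim_users_02 = cosine_similarity(R[0], R[2])        # users with no items in common
sim_items_01 = cosine_similarity(R[:, 0], R[:, 1])  # items: columns instead of rows
print(sim_users_01, sim_users_02, sim_items_01)
```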
Users' ratings of items are not random. They are governed by ... something. We don't really know. Ratings could depend on many factors relating to users, items, and external circumstances.
We assume there are some (latent) factors that can be regarded as responsible for how users review restaurants. Without explicitly constructing these factors, we can construct matrices to represent them.
The utility matrix $R$ can be factorized in the following way:
$$ R \approx PQ, \qquad P \in \mathbb{R}^{n \times k}, \quad Q \in \mathbb{R}^{k \times m} $$
$n$: number of users
$m$: number of items (businesses)
$k$: number of factors we allow for deciding ratings
The prediction for user $i$ on item $j$ is $S_{ij} = (PQ)_{ij}$.
The goal is to factor the utility matrix so that the prediction error is small, i.e. find $P$ and $Q$ that minimize the root mean square error
$$ RMSE = \sqrt{\frac{1}{|\Omega|} \sum_{(i,j) \in \Omega}\left(S_{ij} - r_{ij}\right)^2} $$
where $\Omega$ is the set of (user, item) pairs with an observed rating.
Instead of explicitly performing the $PQ$ multiplication, we can compute $S_{ij} = P_{i,:}Q_{:,j}$ only for non-zero values in the utility matrix.
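A minimal sketch of that computation, assuming dense NumPy arrays for readability (a sparse matrix would expose its non-zero indices in a similar way). The function name and the toy factors are illustrative.

```python
import numpy as np


def rmse_observed(R, P, Q):
    """RMSE of the predictions P[i, :] @ Q[:, j] over non-zero entries of R."""
    rows, cols = np.nonzero(R)
    # One dot product per observed rating; the full product P @ Q is never formed.
    preds = np.sum(P[rows, :] * Q[:, cols].T, axis=1)
    return float(np.sqrt(np.mean((preds - R[rows, cols]) ** 2)))


# Toy example: two users, two items, k = 2 factors.
R = np.array([[5.0, 0.0], [0.0, 3.0]])
P = np.array([[1.0, 2.0], [2.0, 0.0]])
Q = np.array([[1.0, 2.0], [2.0, 0.0]])
err = rmse_observed(R, P, Q)
print(err)  # predictions are 5 and 4 against ratings 5 and 3
```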
In linear regression, we had a data matrix $\mathbf{X}$ and a target $\mathbf{y}$ and we wanted to find a coefficient vector $\mathbf{w}$ so that
$$ \mathbf{y} \approx \mathbf{X}\mathbf{w} $$
with an $L_2$ penalty $\lambda \mathbf{w}^T\mathbf{w}$ on the coefficients.
To find the $\mathbf{w}$ that minimizes the regularized squared error $\left(\displaystyle\sum_{n=1}^N (\mathbf{w}^T\mathbf{x}_n - y_n)^2 + \lambda\|\mathbf{w}\|_2^2 \right)$ there is a closed-form solution:
$$ \mathbf{w} = \left(\mathbf{X}^T\mathbf{X} + \lambda\mathbf{I} \right)^{-1}\mathbf{X}^T\mathbf{y} $$
If instead of predicting a single target value, we want to predict a vector for each instance $\mathbf{x}_i$, we need to find a matrix of coefficients $\mathbf{W}$ so that $\mathbf{Y} \approx \mathbf{X}\mathbf{W}$,
and the closed-form solution is $$ \mathbf{W} = \left(\mathbf{X}^T\mathbf{X} + \lambda\mathbf{I} \right)^{-1}\mathbf{X}^T\mathbf{Y} $$
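The closed-form solution can be sketched in a few lines of NumPy. The function name `ridge_fit` and the synthetic data are assumptions for illustration. The sketch solves the linear system instead of forming the inverse, which gives the same result with better numerical behaviour.

```python
import numpy as np


def ridge_fit(X, Y, lmbda):
    """Closed-form ridge solution W = (X^T X + lambda I)^-1 X^T Y."""
    d = X.shape[1]
    # Solve the linear system rather than forming the inverse explicitly.
    return np.linalg.solve(X.T @ X + lmbda * np.eye(d), X.T @ Y)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
W_true = np.array([[1.0, -2.0], [0.5, 0.0], [3.0, 1.0]])
Y = X @ W_true  # noise-free targets, so a tiny lambda should recover W_true
W = ridge_fit(X, Y, lmbda=1e-8)
print(np.allclose(W, W_true, atol=1e-6))
```

The same function handles the single-target case when `Y` is a vector.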
The multivariable regression problem $$ \min_{\mathbf{W}} \|\mathbf{Y}-\mathbf{X}\mathbf{W}\|^2 + \lambda \|\mathbf{W}\|^2 $$ looks very similar to the matrix factorization problem (with the error taken over the observed ratings only) $$ \min_{\mathbf{P},\mathbf{Q}} \|\mathbf{R} - \mathbf{P}\mathbf{Q}\|^2 + \lambda\left(\|\mathbf{P}\|^2 + \|\mathbf{Q}\|^2\right) $$
except in MF, both $\mathbf{P}$ and $\mathbf{Q}$ are unknown.
Therefore the objective is non-convex
and the problem is NP-hard.
Using alternating least squares, we can fix one matrix and solve for the other, then fix the newly solved matrix and solve for the first. We continue to alternate between the two until the matrices converge.
Repeat: fix $\mathbf{Q}$ and solve for $\mathbf{P}$, then fix $\mathbf{P}$ and solve for $\mathbf{Q}$.
Now, in each step the problem is just a multivariable linear regression where the feature matrix is $\mathbf{Q}$ (or $\mathbf{P}$) and the coefficient matrix is $\mathbf{P}$ (or $\mathbf{Q}$).
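The alternation can be sketched as below. This is not the `alsq` routine used later in this section, whose implementation is not shown. It is a small dense-array illustration with a uniform $\lambda$, a fixed number of iterations instead of a convergence test, and made-up data.

```python
import numpy as np


def als(R, k, lmbda, n_iters=20, seed=0):
    """Alternating least squares on the observed (non-zero) entries of R."""
    n, m = R.shape
    rng = np.random.default_rng(seed)
    P = rng.normal(scale=0.1, size=(n, k))
    Q = rng.normal(scale=0.1, size=(k, m))
    observed = R != 0
    reg = lmbda * np.eye(k)
    for _ in range(n_iters):
        # Fix Q, solve a ridge regression for each user's row of P.
        for i in range(n):
            cols = observed[i]
            if cols.any():
                Qi = Q[:, cols]  # k x (items rated by user i)
                P[i] = np.linalg.solve(Qi @ Qi.T + reg, Qi @ R[i, cols])
        # Fix P, solve a ridge regression for each item's column of Q.
        for j in range(m):
            rows = observed[:, j]
            if rows.any():
                Pj = P[rows]  # (users who rated item j) x k
                Q[:, j] = np.linalg.solve(Pj.T @ Pj + reg, Pj.T @ R[rows, j])
    return P, Q


# Toy rank-1 ratings with two entries treated as unobserved (set to 0).
R = np.outer([1.0, 1.5, 2.0, 2.5], [1.0, 2.0, 1.5])
R[0, 2] = 0.0
R[3, 0] = 0.0
P, Q = als(R, k=2, lmbda=0.01, n_iters=50)
mask = R != 0
err = float(np.sqrt(np.mean(((P @ Q) - R)[mask] ** 2)))
print(err)  # training RMSE on the observed entries
```

Each inner solve is exactly the closed-form regression from above, restricted to the ratings that user or item actually has.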
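Modify the objective to use weighted-$\lambda$-regularization [1]
$$ \min_{\mathbf{P},\mathbf{Q}} \sum_{(i,j)\ \mathrm{observed}} \left(r_{ij} - \mathbf{P}_{i,:}\mathbf{Q}_{:,j}\right)^2 + \lambda\left(\sum_i n_{u_i} \|\mathbf{P}_{i,:}\|^2 + \sum_j n_{b_j}\|\mathbf{Q}_{:,j}\|^2\right) $$
to include $n_{u_i}$, the number of businesses user $i$ reviewed, and $n_{b_j}$, the number of users that reviewed business $j$.
Repeat the same alternating updates as before, now with the weighted penalty.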
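As a sketch of what each step becomes under this penalty, following the update given in [1] and rewritten in the notation used here: let $J_i$ be the set of businesses reviewed by user $i$ and $I_j$ the set of users who reviewed business $j$ (both symbols are introduced only for this illustration), so that the number of businesses user $i$ reviewed is $n_{u_i} = |J_i|$ and the number of users that reviewed business $j$ is $n_{b_j} = |I_j|$. With $\mathbf{Q}$ fixed, the row of $\mathbf{P}$ for user $i$ is
$$ \mathbf{P}_{i,:}^T = \left(\mathbf{Q}_{:,J_i}\mathbf{Q}_{:,J_i}^T + \lambda n_{u_i}\mathbf{I}\right)^{-1}\mathbf{Q}_{:,J_i}\mathbf{R}_{i,J_i}^T $$
and, with $\mathbf{P}$ fixed, the column of $\mathbf{Q}$ for business $j$ is
$$ \mathbf{Q}_{:,j} = \left(\mathbf{P}_{I_j,:}^T\mathbf{P}_{I_j,:} + \lambda n_{b_j}\mathbf{I}\right)^{-1}\mathbf{P}_{I_j,:}^T\mathbf{R}_{I_j,j} $$
so users and businesses with more ratings are regularized more strongly.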
[1] Zhou, Y. et al., 2008. Large-scale parallel collaborative filtering for the Netflix Prize. In Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics). pp. 337-348.
# Run ALS for every (lambda, k) pair that has not been fitted yet
for lmbda, k in param_pairs:
    if (lmbda, k) not in results or results[(lmbda, k)]['P'] is None:
        print('lambda {}, k {}'.format(lmbda, k))
        P, Q, train_errors, validation_errors, test_errors = alsq(
            R, M, k, lmbda, max_iters, train_indxs, validation_indxs, test_indxs)
        results[(lmbda, k)] = {'train': train_errors,
                               'validate': validation_errors,
                               'test': test_errors,
                               'P': P,
                               'Q': Q}
lambda 5, k 2
[Iteration 1/20: train RMSE 3.110945782088086, validation RMSE 1.4962039263643991, test RMSE 1.5381778214157509]
[Iteration 2/20: train RMSE 2.2319606259400837, validation RMSE 1.2514785329669016, test RMSE 1.3442927148632347]
[Iteration 3/20: train RMSE 2.0720930170066088, validation RMSE 1.1993932138703314, test RMSE 1.2997340387729424]
[Iteration 4/20: train RMSE 2.020959511903329, validation RMSE 1.182239546764226, test RMSE 1.2780946831626183]
[Iteration 5/20: train RMSE 2.003961204685282, validation RMSE 1.1762542179494468, test RMSE 1.2663021076618328]
[Iteration 6/20: train RMSE 1.9985774477750133, validation RMSE 1.1744357721269125, test RMSE 1.2594438569536164]
[Iteration 7/20: train RMSE 1.9972978359061795, validation RMSE 1.1742838684952777, test RMSE 1.2551062663128012]
[Iteration 8/20: train RMSE 1.9974686965314588, validation RMSE 1.1747077063049038, test RMSE 1.2521876988061313]
[Iteration 9/20: train RMSE 1.9980903628336695, validation RMSE 1.1752541493878865, test RMSE 1.2501606715260039]
[Iteration 10/20: train RMSE 1.9987550986871994, validation RMSE 1.1757812073601024, test RMSE 1.2487294827059452]
[Iteration 11/20: train RMSE 1.999330471853797, validation RMSE 1.1762495411464093, test RMSE 1.2477074865890467]
[Iteration 12/20: train RMSE 1.9997899916945343, validation RMSE 1.1766484810762678, test RMSE 1.2469713768483188]
[Iteration 13/20: train RMSE 2.000134240024294, validation RMSE 1.1769813016006405, test RMSE 1.2464390913251007]
[Iteration 14/20: train RMSE 2.0003696299829867, validation RMSE 1.1772590155511262, test RMSE 1.2460558747309587]
[Iteration 15/20: train RMSE 2.0005064462150624, validation RMSE 1.1774948484174128, test RMSE 1.2457846328446571]
[Iteration 16/20: train RMSE 2.000560009832354, validation RMSE 1.1777003524316605, test RMSE 1.245599340197056]
[Iteration 17/20: train RMSE 2.000552070196739, validation RMSE 1.1778831637317044, test RMSE 1.2454803495385969]
[Iteration 18/20: train RMSE 2.000511520346015, validation RMSE 1.1780461940583802, test RMSE 1.245410742536233]
[Iteration 19/20: train RMSE 2.000467682100518, validation RMSE 1.1781886776372013, test RMSE 1.2453742137930404]
[Iteration 20/20: train RMSE 2.0004363679451993, validation RMSE 1.1783097200620212, test RMSE 1.2453564107584338]
lambda 5, k 5
[Iteration 1/20: train RMSE 2.2582364900115612, validation RMSE 1.2833092534270094, test RMSE 1.3797955417690915]
[Iteration 2/20: train RMSE 1.9662826895473566, validation RMSE 1.2300908620547861, test RMSE 1.3312184673881102]
[Iteration 3/20: train RMSE 1.9363010227663797, validation RMSE 1.2153401521798683, test RMSE 1.3128894392687194]
[Iteration 4/20: train RMSE 1.9280473184812477, validation RMSE 1.2048787778880818, test RMSE 1.2976676657200243]
[Iteration 5/20: train RMSE 1.924853143197236, validation RMSE 1.1976290476372307, test RMSE 1.2860932552764253]
[Iteration 6/20: train RMSE 1.92333337538937, validation RMSE 1.1925514145047118, test RMSE 1.2776514366152298]
[Iteration 7/20: train RMSE 1.9224464360591105, validation RMSE 1.1887904628032877, test RMSE 1.271523020158568]
[Iteration 8/20: train RMSE 1.921785768017903, validation RMSE 1.1857846577545188, test RMSE 1.267013209440568]
[Iteration 9/20: train RMSE 1.9211696310999096, validation RMSE 1.1832171537519374, test RMSE 1.263635280438189]
[Iteration 10/20: train RMSE 1.9205262837755268, validation RMSE 1.1809504455839206, test RMSE 1.2610796384903213]
[Iteration 11/20: train RMSE 1.9198416957216713, validation RMSE 1.1789563683503117, test RMSE 1.2591490947450628]
[Iteration 12/20: train RMSE 1.919123462117578, validation RMSE 1.17725261491611, test RMSE 1.2577064268405733]
[Iteration 13/20: train RMSE 1.918385702912294, validation RMSE 1.1758586216885023, test RMSE 1.2566473036833041]
[Iteration 14/20: train RMSE 1.917645023389206, validation RMSE 1.1747726205004452, test RMSE 1.2558885024536155]
[Iteration 15/20: train RMSE 1.916918043066739, validation RMSE 1.1739672041546259, test RMSE 1.2553620294471326]
[Iteration 16/20: train RMSE 1.9162196100120645, validation RMSE 1.1733974871811717, test RMSE 1.2550118647741773]
[Iteration 17/20: train RMSE 1.915561680380291, validation RMSE 1.173013215258407, test RMSE 1.254792256927266]
[Iteration 18/20: train RMSE 1.9149523384799771, validation RMSE 1.1727679523997485, test RMSE 1.254666718684552]
[Iteration 19/20: train RMSE 1.9143952958249009, validation RMSE 1.172623278910957, test RMSE 1.2546070636639388]
[Iteration 20/20: train RMSE 1.9138903465067056, validation RMSE 1.172549339599572, test RMSE 1.2545921819047434]
lambda 5, k 10
[Iteration 1/20: train RMSE 2.0130694031397036, validation RMSE 1.2279553300848227, test RMSE 1.3165179622979126]
[Iteration 2/20: train RMSE 1.8948603745404622, validation RMSE 1.19572371242983, test RMSE 1.2843355227785749]
[Iteration 3/20: train RMSE 1.887796704414249, validation RMSE 1.193692812801721, test RMSE 1.2778404978850135]
[Iteration 4/20: train RMSE 1.8843489858967635, validation RMSE 1.192327397379505, test RMSE 1.2727161969943317]
[Iteration 5/20: train RMSE 1.881687077187238, validation RMSE 1.1910330943678054, test RMSE 1.2687071150743914]
[Iteration 6/20: train RMSE 1.879583808041864, validation RMSE 1.189800073897304, test RMSE 1.2657314025833553]
[Iteration 7/20: train RMSE 1.8778585196920097, validation RMSE 1.1887761784838924, test RMSE 1.2636124317762356]
[Iteration 8/20: train RMSE 1.8763669998517516, validation RMSE 1.1880310481176257, test RMSE 1.2621477256173501]
[Iteration 9/20: train RMSE 1.8750213206795363, validation RMSE 1.187551268023915, test RMSE 1.2611586209468284]
[Iteration 10/20: train RMSE 1.8737765300386118, validation RMSE 1.187291726103361, test RMSE 1.260507396278861]
[Iteration 11/20: train RMSE 1.8726138340603038, validation RMSE 1.187202215858874, test RMSE 1.2600941428194357]
[Iteration 12/20: train RMSE 1.8715278678891398, validation RMSE 1.187236673473767, test RMSE 1.2598481221369051]
[Iteration 13/20: train RMSE 1.8705187307884497, validation RMSE 1.187355977030994, test RMSE 1.2597196611476074]
[Iteration 14/20: train RMSE 1.8695875156859465, validation RMSE 1.1875286719007319, test RMSE 1.2596738694322136]
[Iteration 15/20: train RMSE 1.8687340583163323, validation RMSE 1.1877308656943886, test RMSE 1.259686079477981]
[Iteration 16/20: train RMSE 1.8679560898897953, validation RMSE 1.1879455818419582, test RMSE 1.2597386733591085]
[Iteration 17/20: train RMSE 1.8672493028287453, validation RMSE 1.1881616902871428, test RMSE 1.2598189455647497]
[Iteration 18/20: train RMSE 1.866607934954759, validation RMSE 1.188372588652509, test RMSE 1.2599176768412186]
[Iteration 19/20: train RMSE 1.8660255137606767, validation RMSE 1.1885748632554505, test RMSE 1.26002815963478]
[Iteration 20/20: train RMSE 1.8654955040972248, validation RMSE 1.1887671288619657, test RMSE 1.2601455051265287]
lambda 5, k 20
[Iteration 1/20: train RMSE 2.0463711193706526, validation RMSE 1.2772807319296917, test RMSE 1.3178362054101425]
[Iteration 2/20: train RMSE 1.896799734723713, validation RMSE 1.2003723277653142, test RMSE 1.267475149993136]
[Iteration 3/20: train RMSE 1.878390328698135, validation RMSE 1.1897554604302627, test RMSE 1.2616417330812635]
[Iteration 4/20: train RMSE 1.8690711315360404, validation RMSE 1.1851168753370913, test RMSE 1.2601665725734987]
[Iteration 5/20: train RMSE 1.8628709366359777, validation RMSE 1.182724031767976, test RMSE 1.2598779056752092]
[Iteration 6/20: train RMSE 1.858223179405607, validation RMSE 1.1815620040451655, test RMSE 1.2600055136355535]
[Iteration 7/20: train RMSE 1.8545674845617386, validation RMSE 1.1810764615313718, test RMSE 1.2602869911927086]
[Iteration 8/20: train RMSE 1.8516249256126627, validation RMSE 1.1809506460607095, test RMSE 1.2606168382853482]
[Iteration 9/20: train RMSE 1.849217974816988, validation RMSE 1.1810092662972773, test RMSE 1.2609504605286195]
[Iteration 10/20: train RMSE 1.8472218688666076, validation RMSE 1.1811565764320797, test RMSE 1.2612693732335383]
[Iteration 11/20: train RMSE 1.8455458482143956, validation RMSE 1.1813406219193683, test RMSE 1.2615667292251087]
[Iteration 12/20: train RMSE 1.8441228039458015, validation RMSE 1.1815336859345285, test RMSE 1.2618409446829162]
[Iteration 13/20: train RMSE 1.8429024349926133, validation RMSE 1.1817216180979084, test RMSE 1.2620928141841692]
[Iteration 14/20: train RMSE 1.8418464633209854, validation RMSE 1.1818978718658466, test RMSE 1.2623241656734427]
[Iteration 15/20: train RMSE 1.840925271633097, validation RMSE 1.1820600876708942, test RMSE 1.2625372006451576]
[Iteration 16/20: train RMSE 1.8401155690877569, validation RMSE 1.1822081178992743, test RMSE 1.2627341491481996]
[Iteration 17/20: train RMSE 1.8393987887544572, validation RMSE 1.1823429019532823, test RMSE 1.2629170834763503]
[Iteration 18/20: train RMSE 1.8387599841450062, validation RMSE 1.1824658472310547, test RMSE 1.2630878239976253]
[Iteration 19/20: train RMSE 1.8381870507103393, validation RMSE 1.182578503837484, test RMSE 1.2632479046674974]
[Iteration 20/20: train RMSE 1.8376701531777342, validation RMSE 1.182682401635325, test RMSE 1.2633985777785317]
plot_results_reg(results,lmbdas,ks)
# Baseline: predict each business's mean rating
import numpy as np
import scipy.sparse

R = scipy.sparse.csc_matrix(R, dtype=np.uint32)  # wider dtype to fix sum overflow
count_nnz = R.getnnz(0)              # number of ratings per business (column)
sums = np.array(R.sum(0)).flatten()  # sum of ratings per business
means = np.zeros(len(count_nnz))
for i, c in enumerate(count_nnz):
    if c != 0:
        means[i] = sums[i] / c
# Businesses without ratings fall back to the mean of the per-business means
overall_mean = np.mean(means[means != 0])
means[means == 0] = overall_mean
avg_errors = []
for i, j in test_indxs:
    actual = M[i, j]
    avg = means[j]
    if avg == 0:
        avg = overall_mean
    avgerr = rmse_pred(actual, avg)
    avg_errors += [avgerr]
# Average Baseline error
bl_err = np.mean(avg_errors)
print('Average Baseline MSE: {:.4f}'.format(bl_err))
Average Baseline MSE: 1.1550
plt.figure()
plt.hist(avg_errors)
plt.xlabel('MSE')
plt.title('MSE for test set')
# Calculate MSE for each parameter pair
MF_errors = {(l, k): [] for l, k in results.keys()}
MFVal_errors = {(l, k): [] for l, k in results.keys()}
# Matrix Factorization
for l, k in param_pairs:
    if results[(l, k)]['P'] is not None:
        P = results[(l, k)]['P']
        Q = results[(l, k)]['Q']
        # Per-rating errors on the validation set
        for i, j in validation_indxs:
            actual = M[i, j]
            MF = predict(P, Q, i, j)
            MFerr = rmse_pred(actual, MF)
            if len(MFerr) > 0:
                MFVal_errors[l, k] += [MFerr[0]]
        # Per-rating errors on the test set
        for i, j in test_indxs:
            actual = M[i, j]
            MF = predict(P, Q, i, j)
            MFerr = rmse_pred(actual, MF)
            if len(MFerr) > 0:
                MF_errors[l, k] += [MFerr[0]]
        valerr = np.sqrt(np.sum(MFVal_errors[l, k]) / len(MFVal_errors[l, k]))
        testerr = np.sqrt(np.sum(MF_errors[l, k]) / len(MF_errors[l, k]))
        print('lambda = {}, k = {} \t| val error {:.4f} \t| test error {:.4f}'.format(l, k, valerr, testerr))
lambda = 0.1, k = 2 | val error 1.0147 | test error 1.2590
lambda = 0.1, k = 5 | val error 1.0973 | test error 1.2806
lambda = 0.1, k = 10 | val error 1.1478 | test error 1.2269
lambda = 0.1, k = 20 | val error 1.1303 | test error 1.1761
lambda = 1, k = 2 | val error 0.9807 | test error 1.1531
lambda = 1, k = 5 | val error 1.0564 | test error 1.2341
lambda = 1, k = 10 | val error 1.1174 | test error 1.2155
lambda = 1, k = 20 | val error 1.0953 | test error 1.1896
lambda = 2, k = 2 | val error 0.9940 | test error 1.1375
lambda = 2, k = 5 | val error 1.0471 | test error 1.1889
lambda = 2, k = 10 | val error 1.0767 | test error 1.1943
lambda = 2, k = 20 | val error 1.0682 | test error 1.1877
lambda = 5, k = 2 | val error 1.0715 | test error 1.2309
lambda = 5, k = 5 | val error 1.0663 | test error 1.2385
lambda = 5, k = 10 | val error 1.0795 | test error 1.2454
lambda = 5, k = 20 | val error 1.0749 | test error 1.2492
lambda = 10, k = 2 | val error 1.1944 | test error 1.4074
lambda = 10, k = 5 | val error 1.1900 | test error 1.4070
lambda = 10, k = 10 | val error 1.1902 | test error 1.4079
lambda = 10, k = 20 | val error 1.1904 | test error 1.4079
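A small sketch of how a parameter pair could be chosen from these numbers, assuming selection is done on the validation error alone. The dictionary holds only the $k = 2$ rows copied from the output above, and the variable names are illustrative.

```python
# (lambda, k) -> (validation error, test error), copied from the k = 2 rows above.
errors = {
    (0.1, 2): (1.0147, 1.2590),
    (1, 2): (0.9807, 1.1531),
    (2, 2): (0.9940, 1.1375),
    (5, 2): (1.0715, 1.2309),
    (10, 2): (1.1944, 1.4074),
}
# Choose hyperparameters on the validation error only; the test error is
# looked up afterwards and never used for the choice.
best = min(errors, key=lambda pair: errors[pair][0])
val_err, test_err = errors[best]
print(best, val_err, test_err)  # (1, 2) 0.9807 1.1531
```

Note that the pair with the lowest validation error, (1, 2), is not the pair with the lowest test error, (2, 2). This can happen when selection is done on validation data.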
# effect of lambda
fig = plt.figure()
for k in ks:
    errs = [np.sqrt(np.sum(MFVal_errors[l, k]) / len(MFVal_errors[l, k])) for l in lmbdas]
    plt.plot(lmbdas, errs, label='k={}'.format(k))
plt.legend(loc=4)
plt.title('Validation MSE')
plt.xlabel(r'$\lambda$')  # raw string so the backslash is not treated as an escape
# effect of k
fig = plt.figure()
for l in lmbdas:
    errs = [np.sqrt(np.sum(MFVal_errors[l, k]) / len(MFVal_errors[l, k])) for k in ks]
    plt.plot(ks, errs, label=r'$\lambda$={}'.format(l))  # raw string for the backslash
plt.legend(loc=4)
plt.title('Validation MSE')
plt.xlabel('$k$')