In this article, we will build an anime recommendation system using the LambdaRank algorithm. LambdaRank was first introduced by a research group at Microsoft, and it is now available in Microsoft's LightGBM library with an easy-to-use sklearn wrapper. Let's start.
From Search Engines to Product Recommendations

As mentioned above, ranking is widely used in search engines. But it's not limited to search engines; we can adopt the concept and build a solution wherever it applies. Suppose we want to develop a ranking model for a search engine. We should start with a dataset containing queries, their associated documents (URLs), and a relevance score for each query-document pair, as shown below.
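Since the original illustration isn't reproduced here, a toy example of that format (the queries, URLs, and scores below are entirely made up) might look like this:

import pandas as pd

toy_ltr_data = pd.DataFrame({
    "query": ["best anime 2020", "best anime 2020", "anime movies", "anime movies"],
    "document": ["url_1", "url_2", "url_3", "url_4"],
    "relevance": [3, 1, 2, 0],  # higher means more relevant to the query
})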
Finally, we can derive features for each query-document pair.
This is the dataset format used by well-known Learning to Rank research papers and datasets. That covers search engines; now let's discuss how to adapt these concepts to a traditional product recommendation task. There are various ways to recommend products to users: you can show recommendations when the user purchases an item, while the user is browsing a page, and so on. For simplicity, let's narrow it down to a specific scenario.
This article will build an anime recommendation model for personalizing a user's homepage. When the user logs into their account, we show animes ordered by the relevance scores predicted by the ranker model. Here I will use the Anime Recommendation LTR Dataset, which is available under a public license. You can access it via this kaggle link.
import pandas as pd
import numpy as np
import zipfile
import matplotlib.pyplot as plt
import warnings
from sklearn import preprocessing,model_selection,metrics
from lightgbm import LGBMRanker
warnings.filterwarnings('ignore')
from tabulate import tabulate
from functools import wraps
def print_tabulate(df):
    if isinstance(df, pd.DataFrame):
        print(tabulate(df, headers='keys', tablefmt='psql'))
    else:
        return df
!kaggle datasets download -d ransakaravihara/anime-recommendation-ltr-dataset
anime-recommendation-ltr-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)
zipped_data = zipfile.ZipFile("anime-recommendation-ltr-dataset.zip")
zipped_data.namelist()
['anime_info.csv', 'relavence_scores.csv', 'user_info.csv']
Let's read the dataset and quickly check its columns.
anime_info_df = pd.read_csv(zipped_data.open('anime_info.csv'))
relavence_scores = pd.read_csv(zipped_data.open('relavence_scores.csv'))
user_info = pd.read_csv(zipped_data.open('user_info.csv'))
relavence_scores = relavence_scores[~(relavence_scores['user_id']==11100)]
anime_info_df.head()
|   | anime_id | Genres | is_tv | year_aired | is_adult | above_five_star_users | above_five_star_ratings | above_five_star_ratio |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | Action, Adventure, Comedy, Drama, Sci-Fi, Space | 1 | 1998.0 | 1 | 4012.0 | 4012.0 | 0.594018 |
| 1 | 5 | Action, Drama, Mystery, Sci-Fi, Space | 0 | 2001.0 | 1 | NaN | NaN | NaN |
| 2 | 6 | Action, Sci-Fi, Adventure, Comedy, Drama, Shounen | 1 | 1998.0 | 0 | NaN | NaN | NaN |
| 3 | 7 | Action, Mystery, Police, Supernatural, Drama, ... | 1 | 2002.0 | 0 | NaN | NaN | NaN |
| 4 | 8 | Adventure, Fantasy, Shounen, Supernatural | 1 | 2004.0 | 0 | 76.0 | 76.0 | 0.481013 |
Alright, let me explain the methodology I am going to use here. Unlike search-engine data, we don't have query-document pairs in this use case. So we will treat users as queries and the animes they are interested in as documents: just as one search query can be associated with multiple documents, one user can interact with many animes. You got the idea, right :)

The target label we predict here is relavence_score (note the spelling used in the dataset), stored in the relavence_scores dataset. When we build a ranking model, it learns a function that orders these animes so that the most relevant ones come first for each user.
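To make the user-as-query analogy concrete, here is a quick peek at those triplets (a small sketch; the column names follow the dataset's own spelling):

# user_id plays the role of the query, anime_id the document,
# and relavence_score the graded relevance label
relavence_scores[['user_id', 'anime_id', 'Name', 'relavence_score']].head()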
Next, we need to build a training dataset by joining the three datasets above. I will also derive new features from the anime and user attributes.
popular_genres = ['Comedy',
'Action',
'Fantasy',
'Adventure',
'Kids',
'Drama',
'Sci-Fi',
'Music',
'Shounen',
'Slice of Life']
from sklearn.preprocessing import MultiLabelBinarizer

def create_genre_flags(df, popular_genres):
    df = df.dropna(subset=['Genres'])
    # normalize separators so every genre is comma-separated without stray spaces
    df['Genres'] = df['Genres'].apply(lambda x: ",".join(s.strip() for s in x.split(",")))
    # use MultiLabelBinarizer to create a one-hot encoded dataframe of the genres
    mlb = MultiLabelBinarizer()
    genre_df = pd.DataFrame(mlb.fit_transform(df['Genres'].str.split(',')),
                            columns=mlb.classes_,
                            index=df.index)
    # create a new dataframe with the anime id and the popular-genre columns
    new_df = pd.concat([df['anime_id'], genre_df[popular_genres]], axis=1)
    new_df.columns = ['anime_id'] + popular_genres
    return new_df
anime_genre_info_df = create_genre_flags(anime_info_df,popular_genres)
anime_info_df_final = anime_info_df.merge(anime_genre_info_df,on='anime_id')
del anime_info_df_final['Genres']
anime_info_df_final.columns = [col if col=='anime_id' else f"ANIME_FEATURE {col}".upper() for col in anime_info_df_final.columns]
user_info.columns = [col if col=='user_id' else f"USER_FEATURE {col}".upper() for col in user_info.columns]
def create_histogram(data, xname, n_bins=100, background_color='#F3F3F3', title=None):
    import matplotlib as mpl
    mpl.rcParams['font.size'] = 15
    mpl.rcParams['font.family'] = 'sans-serif'
    plt.figure(figsize=(10, 8))
    plt.hist(data, bins=n_bins, color='#2596be', edgecolor='black')
    plt.xlabel(xname)
    plt.ylabel('Frequency')
    ax = plt.gca()
    ax.set_facecolor(background_color)
    plt.grid(True, which='both', color='lightgray', linewidth=0.5)
    plt.title(title.upper(), fontweight='semibold', fontfamily='sans-serif')
    plt.show()
Let's quickly check a few statistics of our dataset. In total, there are 4.8M user-anime interactions, 15K users, and 16K animes.
Here is the user and anime interaction distribution.
create_histogram(relavence_scores.groupby("user_id")['user_id'].agg('count'),xname='user_id', n_bins=30, title='Animes pre user')
create_histogram(relavence_scores.groupby("anime_id")['anime_id'].agg('count'),xname='anime_id', n_bins=30, title='Users per anime')
f"""there are {relavence_scores['user_id'].nunique():,.0f} unique users and {relavence_scores['anime_id'].nunique():,.0f} unique animes and {len(relavence_scores):,.0f} user-anime interactions"""
'there are 15,162 unique users and 16,109 unique animes and 4,864,759 user-anime interactions'
The chart below shows how the relevance score is distributed.
create_histogram(relavence_scores['relavence_score'],xname='relevance_scores', n_bins=30, title='Relevance Score Distribution')
Now we have created the dataset for model training. But one of the most confusing parts of a Learning to Rank model is the group parameter, mainly because it doesn't appear in other machine-learning algorithms. The group parameter partitions the dataset by query: in our case, each group corresponds to one user together with all the animes that user interacted with. It lets the model learn the relative ordering of items within each group, which is exactly what a ranker optimizes. So if the dataset contains ten users, our LTR model has ten groups. Groups can differ in size, but the group sizes must sum to the total number of samples. In simple words, the group parameter tells the model where each group's boundaries are, so it can treat every user's animes as a separate ranking problem.
For example, if you have a 6-record dataset with group = [3, 2, 1], you have three groups: the first three records form the first group, records 4–5 form the second group, and record 6 forms the third group.
So it's vital to sort the dataset by user_id before creating the group parameter.
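To make the group idea concrete, here is a minimal, self-contained sketch of that 6-record example (the features and labels are made up; min_child_samples is lowered only so the toy data can fit):

import numpy as np
from lightgbm import LGBMRanker

# 6 records split into 3 query groups: rows 0-2, rows 3-4, and row 5
X_toy = np.random.rand(6, 4)     # made-up feature matrix
y_toy = [3, 2, 0, 1, 0, 2]       # made-up integer relevance labels
group_toy = [3, 2, 1]            # group sizes must sum to len(X_toy)

toy_ranker = LGBMRanker(objective="lambdarank", n_estimators=5, min_child_samples=1)
toy_ranker.fit(X_toy, y_toy, group=group_toy)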
train_interim = relavence_scores.merge(anime_info_df_final)
train = train_interim.merge(user_info,how='inner')
na_counts = (train.isna().sum() * 100/len(train))
train_processed = train.drop(na_counts[na_counts > 50].index,axis=1)
train_processed.sort_values(by='user_id',inplace=True)
train_processed.columns
Index(['anime_id', 'Name', 'user_id', 'relavence_score', 'ANIME_FEATURE IS_TV',
'ANIME_FEATURE YEAR_AIRED', 'ANIME_FEATURE IS_ADULT',
'ANIME_FEATURE ABOVE_FIVE_STAR_USERS',
'ANIME_FEATURE ABOVE_FIVE_STAR_RATINGS',
'ANIME_FEATURE ABOVE_FIVE_STAR_RATIO', 'ANIME_FEATURE COMEDY',
'ANIME_FEATURE ACTION', 'ANIME_FEATURE FANTASY',
'ANIME_FEATURE ADVENTURE', 'ANIME_FEATURE KIDS', 'ANIME_FEATURE DRAMA',
'ANIME_FEATURE SCI-FI', 'ANIME_FEATURE MUSIC', 'ANIME_FEATURE SHOUNEN',
'ANIME_FEATURE SLICE OF LIFE', 'USER_FEATURE REVIEW_COUNT',
'USER_FEATURE AVG_SCORE', 'USER_FEATURE SCORE_STDDEV',
'USER_FEATURE ABOVE_FIVE_STAR_COUNT',
'USER_FEATURE ABOVE_FIVE_STAR_RATIO'],
dtype='object')
train_processed.set_index("user_id",inplace=True)
features = ['ANIME_FEATURE IS_TV',
'ANIME_FEATURE YEAR_AIRED', 'ANIME_FEATURE IS_ADULT',
'ANIME_FEATURE ABOVE_FIVE_STAR_USERS',
'ANIME_FEATURE ABOVE_FIVE_STAR_RATINGS',
'ANIME_FEATURE ABOVE_FIVE_STAR_RATIO', 'ANIME_FEATURE COMEDY',
'ANIME_FEATURE ACTION', 'ANIME_FEATURE FANTASY',
'ANIME_FEATURE ADVENTURE', 'ANIME_FEATURE KIDS', 'ANIME_FEATURE DRAMA',
'ANIME_FEATURE SCI-FI', 'ANIME_FEATURE MUSIC', 'ANIME_FEATURE SHOUNEN',
'ANIME_FEATURE SLICE OF LIFE', 'USER_FEATURE REVIEW_COUNT',
'USER_FEATURE AVG_SCORE', 'USER_FEATURE SCORE_STDDEV',
'USER_FEATURE ABOVE_FIVE_STAR_COUNT',
'USER_FEATURE ABOVE_FIVE_STAR_RATIO']
target = 'relavence_score'
from sklearn.model_selection import train_test_split
from sklearn.metrics import ndcg_score,coverage_error
test_size = int(1e5)
X,y = train_processed[features],train_processed[target].apply(lambda x:int(x*10))
test_idx_start = len(X)-test_size
xtrain,xtest,ytrain,ytest = X.iloc[0:test_idx_start],X.iloc[test_idx_start:],y.iloc[0:test_idx_start],y.iloc[test_idx_start:]
get_group_size = lambda df: df.reset_index().groupby("user_id")['user_id'].count()
train_groups = get_group_size(xtrain)
test_groups = get_group_size(xtest)
sum(train_groups) , sum(test_groups)
(4764372, 100000)
In the above code snippet, I have done the following steps:
- Dropped all columns that have more than 50% null values.
- Sorted the dataset by user_id; otherwise, the group parameter would produce inconsistent groups all over the place.
- After sorting, declared the last 100,000 rows as validation data and the rest as training data.
- Since the LightGBM ranker expects integer target values, multiplied the relevance score by 10 and truncated it to an integer, giving labels in the 0-10 range. (If you want, try converting it to 0 or 1 based on a threshold.)
- Defined the groups parameter for the train and test data.
model = LGBMRanker(objective="lambdarank",n_estimators=20)
model.fit(xtrain,ytrain,group=train_groups,eval_set=[(xtest,ytest)],eval_group=[test_groups],eval_metric=['map'])
[1] valid_0's map@1: 0.863354 valid_0's map@2: 0.84705 valid_0's map@3: 0.825397 valid_0's map@4: 0.816447 valid_0's map@5: 0.809951 valid_0's ndcg@1: 0.854927 valid_0's ndcg@2: 0.857799 valid_0's ndcg@3: 0.849489 valid_0's ndcg@4: 0.845283 valid_0's ndcg@5: 0.844731
[2] valid_0's map@1: 0.875776 valid_0's map@2: 0.863354 valid_0's map@3: 0.8505 valid_0's map@4: 0.832665 valid_0's map@5: 0.829016 valid_0's ndcg@1: 0.860415 valid_0's ndcg@2: 0.870321 valid_0's ndcg@3: 0.867802 valid_0's ndcg@4: 0.857057 valid_0's ndcg@5: 0.857424
[3] valid_0's map@1: 0.881988 valid_0's map@2: 0.865683 valid_0's map@3: 0.848948 valid_0's map@4: 0.833053 valid_0's map@5: 0.823736 valid_0's ndcg@1: 0.866627 valid_0's ndcg@2: 0.87533 valid_0's ndcg@3: 0.867993 valid_0's ndcg@4: 0.859304 valid_0's ndcg@5: 0.854882
[4] valid_0's map@1: 0.872671 valid_0's map@2: 0.863354 valid_0's map@3: 0.844979 valid_0's map@4: 0.835706 valid_0's map@5: 0.830455 valid_0's ndcg@1: 0.857352 valid_0's ndcg@2: 0.867246 valid_0's ndcg@3: 0.86469 valid_0's ndcg@4: 0.860217 valid_0's ndcg@5: 0.859755
[5] valid_0's map@1: 0.878882 valid_0's map@2: 0.867236 valid_0's map@3: 0.847222 valid_0's map@4: 0.837905 valid_0's map@5: 0.832194 valid_0's ndcg@1: 0.866624 valid_0's ndcg@2: 0.874148 valid_0's ndcg@3: 0.866353 valid_0's ndcg@4: 0.863071 valid_0's ndcg@5: 0.861296
[6] valid_0's map@1: 0.875776 valid_0's map@2: 0.858696 valid_0's map@3: 0.844117 valid_0's map@4: 0.836029 valid_0's map@5: 0.829942 valid_0's ndcg@1: 0.860458 valid_0's ndcg@2: 0.86674 valid_0's ndcg@3: 0.86433 valid_0's ndcg@4: 0.86191 valid_0's ndcg@5: 0.859922
[7] valid_0's map@1: 0.869565 valid_0's map@2: 0.857143 valid_0's map@3: 0.843858 valid_0's map@4: 0.836417 valid_0's map@5: 0.829104 valid_0's ndcg@1: 0.857307 valid_0's ndcg@2: 0.86677 valid_0's ndcg@3: 0.864643 valid_0's ndcg@4: 0.861585 valid_0's ndcg@5: 0.858774
[8] valid_0's map@1: 0.872671 valid_0's map@2: 0.858696 valid_0's map@3: 0.844893 valid_0's map@4: 0.83467 valid_0's map@5: 0.828079 valid_0's ndcg@1: 0.860412 valid_0's ndcg@2: 0.868674 valid_0's ndcg@3: 0.865372 valid_0's ndcg@4: 0.860625 valid_0's ndcg@5: 0.858359
[9] valid_0's map@1: 0.869565 valid_0's map@2: 0.856366 valid_0's map@3: 0.842305 valid_0's map@4: 0.832147 valid_0's map@5: 0.827603 valid_0's ndcg@1: 0.857307 valid_0's ndcg@2: 0.86677 valid_0's ndcg@3: 0.863186 valid_0's ndcg@4: 0.858443 valid_0's ndcg@5: 0.857633
[10] valid_0's map@1: 0.86646 valid_0's map@2: 0.860248 valid_0's map@3: 0.844203 valid_0's map@4: 0.832018 valid_0's map@5: 0.827375 valid_0's ndcg@1: 0.854201 valid_0's ndcg@2: 0.868487 valid_0's ndcg@3: 0.863761 valid_0's ndcg@4: 0.85827 valid_0's ndcg@5: 0.857482
[11] valid_0's map@1: 0.86646 valid_0's map@2: 0.860248 valid_0's map@3: 0.840407 valid_0's map@4: 0.832277 valid_0's map@5: 0.825719 valid_0's ndcg@1: 0.854201 valid_0's ndcg@2: 0.868487 valid_0's ndcg@3: 0.860846 valid_0's ndcg@4: 0.858456 valid_0's ndcg@5: 0.856013
[12] valid_0's map@1: 0.878882 valid_0's map@2: 0.858696 valid_0's map@3: 0.840407 valid_0's map@4: 0.835964 valid_0's map@5: 0.832768 valid_0's ndcg@1: 0.866624 valid_0's ndcg@2: 0.867694 valid_0's ndcg@3: 0.861697 valid_0's ndcg@4: 0.861771 valid_0's ndcg@5: 0.860532
[13] valid_0's map@1: 0.875776 valid_0's map@2: 0.859472 valid_0's map@3: 0.842219 valid_0's map@4: 0.83687 valid_0's map@5: 0.833441 valid_0's ndcg@1: 0.863518 valid_0's ndcg@2: 0.866651 valid_0's ndcg@3: 0.861758 valid_0's ndcg@4: 0.86229 valid_0's ndcg@5: 0.860531
[14] valid_0's map@1: 0.885093 valid_0's map@2: 0.864907 valid_0's map@3: 0.848948 valid_0's map@4: 0.843276 valid_0's map@5: 0.838317 valid_0's ndcg@1: 0.872835 valid_0's ndcg@2: 0.871162 valid_0's ndcg@3: 0.867397 valid_0's ndcg@4: 0.867503 valid_0's ndcg@5: 0.864245
[15] valid_0's map@1: 0.885093 valid_0's map@2: 0.869565 valid_0's map@3: 0.852053 valid_0's map@4: 0.845217 valid_0's map@5: 0.840615 valid_0's ndcg@1: 0.872835 valid_0's ndcg@2: 0.874766 valid_0's ndcg@3: 0.869427 valid_0's ndcg@4: 0.868669 valid_0's ndcg@5: 0.865704
[16] valid_0's map@1: 0.885093 valid_0's map@2: 0.871118 valid_0's map@3: 0.854814 valid_0's map@4: 0.846511 valid_0's map@5: 0.840781 valid_0's ndcg@1: 0.872835 valid_0's ndcg@2: 0.875967 valid_0's ndcg@3: 0.871804 valid_0's ndcg@4: 0.869606 valid_0's ndcg@5: 0.865294
[17] valid_0's map@1: 0.885093 valid_0's map@2: 0.867236 valid_0's map@3: 0.851881 valid_0's map@4: 0.845475 valid_0's map@5: 0.838959 valid_0's ndcg@1: 0.872835 valid_0's ndcg@2: 0.872363 valid_0's ndcg@3: 0.869045 valid_0's ndcg@4: 0.869399 valid_0's ndcg@5: 0.864301
[18] valid_0's map@1: 0.885093 valid_0's map@2: 0.868012 valid_0's map@3: 0.855245 valid_0's map@4: 0.847093 valid_0's map@5: 0.840305 valid_0's ndcg@1: 0.872835 valid_0's ndcg@2: 0.873905 valid_0's ndcg@3: 0.872279 valid_0's ndcg@4: 0.869581 valid_0's ndcg@5: 0.865685
[19] valid_0's map@1: 0.888199 valid_0's map@2: 0.869565 valid_0's map@3: 0.859041 valid_0's map@4: 0.847805 valid_0's map@5: 0.842313 valid_0's ndcg@1: 0.87594 valid_0's ndcg@2: 0.87581 valid_0's ndcg@3: 0.875196 valid_0's ndcg@4: 0.870409 valid_0's ndcg@5: 0.867169
[20] valid_0's map@1: 0.885093 valid_0's map@2: 0.870342 valid_0's map@3: 0.860852 valid_0's map@4: 0.850522 valid_0's map@5: 0.843695 valid_0's ndcg@1: 0.875846 valid_0's ndcg@2: 0.875988 valid_0's ndcg@3: 0.877002 valid_0's ndcg@4: 0.873101 valid_0's ndcg@5: 0.86881
LGBMRanker(n_estimators=20, objective='lambdarank')
Alright, now we have finished training. It's time to make some predictions for specific users.
There is one more critical point to notice here: how the model's predictions are calculated. I was confused at first because the .predict method does not expect additional parameters such as group. I found the explanation in GitHub issues: according to the creator of LightGBM in this issue, LightGBM's lambdarank uses a pointwise approach to generate predictions. That means we do not need to provide additional parameters such as group; we can simply call the .predict() method with the correct feature vector.
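As a quick sanity check, scoring a few held-out rows needs nothing but the feature matrix:

# no group parameter at prediction time; each row gets an independent score
model.predict(xtest.head())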
However, since we are developing an LTR model, it's essential to have candidate products and to rank them by predicted relevance score. To generate candidate animes, I will use a straightforward method here:

- Select the animes that have not yet been exposed to a given user.
- Pick a random subset of N animes from those unexposed animes.
- Generate features for the user and the selected animes.
- Use the generated feature vectors to get relevance scores, and sort the animes accordingly.

But in real-world use cases, we should use more meaningful methodologies. As an example, you could generate candidates as follows (a sketch follows this list):
- Select the user's favourite N genres.
- For each of those genres, pick the highest-rated M animes. Now you have up to N × M animes to rank for that user. Create the user-based and anime-based features, and finally call the .predict() method with the resulting feature vectors.
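Here is a rough sketch of that genre-based strategy using this article's data. The choices of mean relevance as the "favourite genre" signal and ANIME_FEATURE ABOVE_FIVE_STAR_RATIO as the "highest-rated" proxy are my assumptions, not something defined above:

def genre_based_candidates(user_id, n_genres=3, m_per_genre=20):
    # favourite genres: highest mean relevance-weighted genre score for this user
    user_animes = relavence_scores[relavence_scores['user_id'] == user_id]
    user_genres = user_animes.merge(anime_genre_info_df, on='anime_id')
    genre_scores = user_genres[popular_genres].multiply(user_genres['relavence_score'], axis=0).mean()
    top_genres = genre_scores.nlargest(n_genres).index

    candidates = []
    for genre in top_genres:
        flag_col = f"ANIME_FEATURE {genre}".upper()
        in_genre = anime_info_df_final[anime_info_df_final[flag_col] == 1]
        # 'highest-rated' proxy: share of raters who scored the anime above five stars
        top_m = in_genre.nlargest(m_per_genre, 'ANIME_FEATURE ABOVE_FIVE_STAR_RATIO')
        candidates.append(top_m['anime_id'])
    return pd.concat(candidates).drop_duplicates().tolist()

The resulting IDs can then go through the same feature-building and .predict() path as the random candidates below.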
user_2_anime_df = relavence_scores.groupby("user_id").agg({"anime_id":lambda x:list(set(x))})
user_2_anime_map = dict(zip(user_2_anime_df.index,user_2_anime_df['anime_id']))
candidate_pool = anime_info_df_final['anime_id'].unique().tolist()
def candidate_generation(user_id, candidate_pool, user_2_anime_map, N):
    # animes the user has already interacted with
    already_interacted = user_2_anime_map[user_id]
    # candidates are everything the user has not seen yet; sample N of them
    candidates = list(set(candidate_pool) - set(already_interacted))
    return already_interacted, np.random.choice(candidates, size=N)
anime_id_2_name = relavence_scores.drop_duplicates(subset=["anime_id","Name"])[['anime_id',"Name"]]
anime_id_2_name_map = dict(zip(anime_id_2_name['anime_id'],anime_id_2_name['Name']))
anime_id_2_name_map[1]
'Cowboy Bebop'
def generate_predictions(user_id, user_2_anime_map, candidate_pool, feature_columns, anime_id_2_name_map, ranker, N=100):
    # sample candidates the user has not interacted with yet
    already_liked, candidates = candidate_generation(user_id, candidate_pool, user_2_anime_map, N=10000)
    candidates_df = pd.DataFrame(data=pd.Series(candidates, name='anime_id'))
    # build the feature vectors: anime features joined with this user's features
    features = anime_info_df_final.merge(candidates_df)
    features['user_id'] = user_id
    features = features.merge(user_info)

    # pad the already-liked list so it can sit alongside the predictions
    already_liked = list(already_liked)
    if len(already_liked) < len(candidates):
        append_list = np.full(fill_value=-1, shape=(len(candidates) - len(already_liked)))
        already_liked.extend(list(append_list))

    predictions = pd.DataFrame(index=candidates)
    predictions['name'] = np.array([anime_id_2_name_map.get(id_) for id_ in candidates])
    predictions['score'] = ranker.predict(features[feature_columns])
    predictions = predictions.sort_values(by='score', ascending=False).head(N)
    predictions[f'already_liked - sample[{N}]'] = [anime_id_2_name_map.get(id_) for id_ in already_liked[0:len(predictions)]]
    return predictions
print_tabulate(generate_predictions(123,user_2_anime_map,candidate_pool,feature_columns=features,anime_id_2_name_map=anime_id_2_name_map,ranker=model,N=10))
+-------+------------------------------------------------------+---------+------------------------------+
| | name | score | already_liked - sample[10] |
|-------+------------------------------------------------------+---------+------------------------------|
| 20267 | Wooser no Sono Higurashi: Kakusei-hen | 2.14745 | Majo no Takkyuubin |
| 3864 | Souten no Ken Specials | 2.14745 | Tenkuu no Shiro Laputa |
| 3613 | Toshokan Sensou | 2.11424 | Pumpkin Scissors |
| 15915 | Magical Hat | 2.11424 | Omoide Poroporo |
| 31370 | Tonkatsu DJ Agetarou | 2.11424 | Heisei Tanuki Gassen Ponpoko |
| 4810 | Shinzou Ningen Casshern | 2.10967 | Tonari no Totoro |
| 40172 | Lucky☆Orb feat. Hatsune Miku | 2.00188 | Zetsuai 1989 |
| 689 | Uta∽Kata: Shotou no Futanatsu | 1.9301 | Monster |
| 401 | Rurouni Kenshin: Meiji Kenkaku Romantan - Seisou-hen | 1.90252 | xxxHOLiC Kei |
| 41490 | Kai Ten Ryo Kou Ki with Tavito Nanao | 1.90252 | Shounen Onmyouji |
+-------+------------------------------------------------------+---------+------------------------------+
import shap

def generate_shap_plots(ranker, X_train, feature_names, N=3):
    """
    Generates SHAP plots for a pre-trained LightGBM model.

    Parameters:
        ranker (LGBMRanker): a trained LightGBM ranking model
        X_train (pd.DataFrame): the training data used to fit the model
        feature_names (list): list of feature names
        N (int): the number of samples to explain

    Returns:
        None
    """
    explainer = shap.Explainer(ranker, X_train, feature_names=feature_names)
    shap_values = explainer(X_train.iloc[:N])
    # bar plot: global feature importance (mean absolute SHAP value)
    plt.subplot(1, 2, 1)
    shap.summary_plot(shap_values, feature_names=feature_names, plot_type='bar')
    # dot plot: per-sample SHAP values for each feature
    plt.subplot(1, 2, 2)
    shap.summary_plot(shap_values, feature_names=feature_names, plot_type='dot')
    plt.show()
Additionally, we can use shap to explain the model's predictions.
generate_shap_plots(model,xtrain,features,N=10000)