A Machine Learning Assignment Note
Recently I received a machine learning assignment:
The goal of this assignment is to predict if an image is relevant or not, given a text query.
More details about this assignment:
The classical search system takes in a query, generates K candidates from a large database, and refines the K candidates to return a final list of “relevant” results.
In this assignment, you are to work on a text-to-image search system where you are provided with generated candidates for a set of queries, and you must provide a solution that classifies which candidates are irrelevant.
In short, it is a binary classification problem where you should predict if a given candidate is relevant or not to the text query.
You are free to use any external data sources or libraries, but please provide references if they form a major part of your solution.
You are allowed to use heuristics, non-ML related techniques, or even exploit leakage (if any). However, hand-labeling is not allowed.
Progress
1 Initial Observations
- Certain columns are heavily skewed, so we plan to exclude them during training and serving: crossorigin, height, ismap, loading, longdesc, referrerpolicy, sizes, srcset, width, class, and style (a pruning sketch follows this list).
- The queries appear to be concatenated strings of multiple terms; we return to this point in the final reflections.
- The columns titled 'title', 'alt', and 'text' contain textual information.
- We will frame this task as a binary classification problem. The objective is to predict, based on the queries and additional features, whether an image is relevant to the provided query.
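As a minimal sketch (assuming the candidate table is already loaded into a pandas DataFrame elsewhere), the pruning step could look like:

```python
import pandas as pd

# Columns observed to be heavily skewed; excluded from training and serving.
SKEWED_COLUMNS = [
    "crossorigin", "height", "ismap", "loading", "longdesc",
    "referrerpolicy", "sizes", "srcset", "width", "class", "style",
]

def drop_skewed_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Drop the skewed columns, ignoring any that are missing."""
    return df.drop(columns=SKEWED_COLUMNS, errors="ignore")
```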
2 Insights from the Baseline Model
- We removed the skewed columns.
- Using simple heuristics, we engineered the features detailed in the table below.
- Having framed the task as binary classification, we opted for XGBoost to assess the heuristic features, owing to its simplicity and efficiency.
- Using XGBoost for binary classification with the selected features, we established a baseline model with an AUC of 0.642 (a minimal sketch of this setup follows the feature table).
Engineered Features | Notes | Type |
---|---|---|
in_title | if the query is in the title | boolean |
in_alt | if the query is in the alt | boolean |
in_text | if the query is in the text | boolean |
same_source | if the image is from the same site | boolean |
text_tree_len | the length of text tree | int |
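As a rough illustration of this setup, not the exact code used, a cross-validated XGBoost run over the heuristic features might look as follows; the `label` column name is an assumption:

```python
import pandas as pd
import xgboost as xgb

# Heuristic features from the table above.
FEATURES = ["in_title", "in_alt", "in_text", "same_source", "text_tree_len"]

def baseline_cv_auc(df: pd.DataFrame, nfold: int = 5, num_boost_round: int = 1000) -> float:
    """Cross-validated XGBoost over the heuristic features; returns the mean test AUC."""
    dtrain = xgb.DMatrix(df[FEATURES].astype(float), label=df["label"])  # "label" is assumed
    params = {"objective": "binary:logistic", "eval_metric": "auc"}
    history = xgb.cv(params, dtrain, num_boost_round=num_boost_round, nfold=nfold)
    return float(history["test-auc-mean"].iloc[-1])
```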
3 Observations from Embeddings
- We next aimed to exploit the textual features further, e.g. the similarity between the queries and the title and alt columns, to see whether the AUC improves.
- We produced intermediate features by embedding the queries, titles, and alts, then computed the cosine similarity between the query and title embeddings, and between the query and alt embeddings. These similarities were used as classification features (refer to the tables below; a code sketch follows the feature table).
- Initially we used BERT embeddings (bert-base-nli-mean-tokens), which yielded an AUC of 0.592, a decline from the 0.642 heuristic baseline.
Embedding | Features | Training sample size | AUC | nfold | boost round |
---|---|---|---|---|---|
bert-base-nli-mean-tokens | similarities (alt, title) + heuristics | 7,500 | 0.592 | 5 | 1000 |
Engineered Features | Notes | Type |
---|---|---|
query_embed | convert the query into an embedding | tensor (not included in classification) |
title_embed | convert the title into an embedding | tensor (not included in classification) |
alt_embed | convert the alt into an embedding | tensor (not included in classification) |
sim_q_title | cos similarity between query embedding and title embedding | float |
sim_q_alt | cos similarity between query embedding and alt embedding | float |
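A sketch of how these similarity features can be computed with sentence-transformers; the plain Python lists stand in for the actual DataFrame columns:

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("bert-base-nli-mean-tokens")

def similarity_features(queries, titles, alts):
    """Row-wise cosine similarity between query/title and query/alt embeddings."""
    q = model.encode(queries, convert_to_tensor=True)
    t = model.encode(titles, convert_to_tensor=True)
    a = model.encode(alts, convert_to_tensor=True)
    # util.cos_sim returns the full pairwise matrix; the diagonal pairs row i with row i.
    sim_q_title = util.cos_sim(q, t).diagonal().tolist()
    sim_q_alt = util.cos_sim(q, a).diagonal().tolist()
    return sim_q_title, sim_q_alt
```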
4 Observations from More Embeddings
- We explored several embeddings to test the efficacy of this approach.
- Using pretrained models from the sentence_transformers package, we observed a notable improvement in performance (a minimal model-swap sketch follows the results table below).
- All subsequent tests were conducted with 5-fold cross validation and 1000 boost rounds.
- Moving from BERT to all-MiniLM-L6-v2 and then to all-mpnet-base-v2 (at 7,500 training samples), the AUC rose from 0.592 to 0.708 and further to 0.723.
- Increasing the training sample size also helped: with all-MiniLM-L6-v2, the AUC climbed from 0.708 to 0.751 to 0.765 as the overall sample grew from 10,000 to 100,000 and finally 600,000 (7,500, 75,000, and 450,000 training rows in the table below).
- The best performance, an AUC of 0.769, came from all-mpnet-base-v2 at the largest sample size.
- Dropping the heuristics and relying solely on the similarity features gave a modest increase from 0.708 to 0.733 with all-MiniLM-L6-v2 (7,500 samples), but a decrease from 0.769 to 0.726 with all-mpnet-base-v2 (450,000 samples).
Embedding | Features | Training sample size | AUC | nfold | boost round |
---|---|---|---|---|---|
bert-base-nli-mean-tokens | similarities (alt, title) + heuristics | 7,500 | 0.592 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 7,500 | 0.708 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 75,000 | 0.751 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 450,000 | 0.765 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title) | 7,500 | 0.733 | 5 | 1000 |
all-mpnet-base-v2 | similarities (alt, title) | 450,000 | 0.726 | 5 | 1000 |
all-mpnet-base-v2 | similarities (alt, title) + heuristics | 7,500 | 0.723 | 5 | 1000 |
all-mpnet-base-v2 | similarities (alt, title) + heuristics | 450,000 | 0.769 | 5 | 1000 |
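Swapping embeddings only means changing the model name passed to `SentenceTransformer`. A toy comparison, with purely illustrative example strings:

```python
from sentence_transformers import SentenceTransformer, util

MODEL_NAMES = [
    "bert-base-nli-mean-tokens",
    "all-MiniLM-L6-v2",
    "all-mpnet-base-v2",
]

query, title = "red running shoes", "Nike running shoes in red"  # toy example
for name in MODEL_NAMES:
    model = SentenceTransformer(name)
    q, t = model.encode([query, title], convert_to_tensor=True)
    print(f"{name}: sim_q_title = {util.cos_sim(q, t).item():.3f}")
```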
5 Performance Improvement
- In our earlier experiments, generating the embeddings with sentence-transformers was a bottleneck, primarily because we used `pandas.Series.apply()` to encode the text row by row.
- After coming across a helpful Reddit post, we reworked the code to pass each whole column to the model in a single batched call (the "vectorization" rows in the table below). The resulting performance boost was substantial.
Sample size | CPU or GPU | User time | System time | Total time | Notes |
---|---|---|---|---|---|
1,000 | CPU | 27.2 s | 12.9 s | 40.1 s | apply() |
1,000 | GPU | 45.4 s | 2.86 s | 48.3 s | apply() |
10,000 | CPU | 4min 31s | 2min 3s | 6min 35s | apply() |
10,000 | GPU | 7min 5s | 23.8 s | 7min 28s | apply() |
10,000 | GPU | 22.2 s | 4.72 s | 26.9 s | vectorization |
100,000 | GPU | 3min 5s | 16.2 s | 3min 21s | vectorization |
600,000 | GPU | 21min 25s | 11min 3s | 32min 28s | vectorization |
600,000 | GPU | 45min 14s | 25min 42s | 1h 10min 56s | vectorization |
```python
# embedding code with pandas.Series.apply()
from sentence_transformers import SentenceTransformer
import pandas as pd

# Model name illustrative; any sentence-transformers model works here.
MODEL_EMBED = SentenceTransformer("all-MiniLM-L6-v2")

def embed_values(value: str):
    # Encode a single cell; NaN cells yield None.
    if pd.notna(value):
        return MODEL_EMBED.encode(value)
    return None

def add_embedding(df: pd.DataFrame, new_column: str, column_to_apply: str) -> pd.DataFrame:
    # apply() invokes the model once per row, which is what makes this slow.
    df.loc[:, new_column] = df[column_to_apply].apply(embed_values)
    return df
```
```python
# embedding code with pandas vectorization
def embed_speedup(df: pd.DataFrame, col: str, embedded_col: str) -> pd.DataFrame:
    # One batched encode() call over the whole column; the model batches
    # internally (and runs on GPU when available), avoiding per-row overhead.
    df[col] = df[col].fillna("")
    embedded = MODEL_EMBED.encode(df[col].tolist())
    df[embedded_col] = embedded.tolist()
    return df
```
6 Adding More Embeddings
- Revisiting the work of the past few days, we realized we had overlooked embedding the text column. After embedding it and computing the cosine similarity between queries and texts, the AUC improved further (a sketch of this step follows the feature table below).
- The highest AUC attained was 0.790, using the all-mpnet-base-v2 embedding.
Embedding | Features | Training sample size | AUC | nfold | boost round |
---|---|---|---|---|---|
all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 7,500 | 0.729 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 75,000 | 0.770 | 5 | 1000 |
all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 450,000 | 0.788 | 5 | 1000 |
all-mpnet-base-v2 | similarities (alt, title, text) + heuristics | 450,000 | 0.790 | 5 | 1000 |
Engineered Features | Notes | Type |
---|---|---|
query_embed | convert the query into an embedding | tensor (not included in classification) |
title_embed | convert the title into an embedding | tensor (not included in classification) |
alt_embed | convert the alt into an embedding | tensor (not included in classification) |
text_embed | convert the text into an embedding | tensor (not included in classification) |
sim_q_title | cos similarity between query embedding and title embedding | float |
sim_q_alt | cos similarity between query embedding and alt embedding | float |
sim_q_text | cos similarity between query embedding and text embedding | float |
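A sketch of the added sim_q_text step, assuming the DataFrame already holds the embedding columns from the table above as per-row vectors:

```python
import numpy as np
import pandas as pd

def rowwise_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between corresponding rows of two (n, dim) arrays."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return (a * b).sum(axis=1)

def add_sim_q_text(df: pd.DataFrame) -> pd.DataFrame:
    # Stack the per-row embedding vectors into (n, dim) matrices.
    q = np.vstack(df["query_embed"].to_numpy())
    x = np.vstack(df["text_embed"].to_numpy())
    df["sim_q_text"] = rowwise_cosine(q, x)
    return df
```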
7 Running on Tests
- Running the classifier on the test set produced a sharp drop in AUC: from 0.790 in cross-validation to 0.501 on the test set, i.e. barely better than random.
- This drop can most likely be attributed to overfitting; in future work we would explore how to mitigate it for this task.
Final Reflections
- Due to time constraints, we decided against crawling the images themselves and instead concentrated on engineering useful features from the available columns.
- We framed the challenge as a binary classification problem, choosing XGBoost for its simplicity and efficiency.
- Heuristics laid the groundwork. While the outcome was weaker than hoped, this quick step established a baseline, and the heuristic features still added a subtle lift when combined with embeddings.
- Recognizing that the queries were concatenated, we used embeddings to capture the underlying semantics and measured similarities between the embeddings of queries and of the text, alt, and title columns. This approach proved effective.
- Experimenting with varied embeddings improved the final AUC, although the initial BERT embedding dipped slightly below the heuristic baseline (0.592 vs 0.642).
- Through the sentence-transformers package we found more capable embeddings, namely all-MiniLM-L6-v2 and all-mpnet-base-v2.
- A balance between embedding quality and speed was required. At the outset, performance concerns made all-MiniLM-L6-v2 our primary choice for experiments; it markedly improved the classification results.
- Later, batched encoding on GPU (the pandas "vectorization" above) removed the bottleneck, letting us adopt all-mpnet-base-v2 and push the results further.
- Running the trained classifier on the test dataset revealed a pronounced overfitting issue. Given more time, we would investigate this more deeply. Probable culprits include:
- A severely imbalanced dataset.
- Concatenated queries that lost their inherent meaning.
- Given more time to address the overfitting, our course of action would be to:
- Use data augmentation techniques to generate more positive samples, aiming for a more balanced dataset.
- Parse and disentangle the concatenated queries to recover their inherent meaning (one possible approach is sketched below).
- Crawl the images and deploy multimodal models to extract visual features.
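For the query-disentangling idea, one option we would try (an assumption on our part, not something attempted here) is frequency-based word segmentation, e.g. with the third-party wordninja package:

```python
import wordninja

# Splits a run-together string into likely words using English unigram frequencies.
print(wordninja.split("redrunningshoes"))  # e.g. ['red', 'running', 'shoes']
```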
- Since the data preprocessing code is reused by both the training and prediction phases, we modularized it into a separate file: `data_processing.py`.