A Machine Learning Assignment Note

Photo by Job Savelsberg / Unsplash

Recently I received a machine learning assignment:

The goal of this assignment is to predict if an image is relevant or not, given a text query.

More details about this assignment:

The classical search system takes in a query, generates K candidates from a large database, and refines the K candidates to return a final list of “relevant” results.

In this assignment, you are to work on a text-to-image search system. You are provided with generated candidates for some queries, and you must build a solution that identifies the irrelevant candidates.

In short, it is a binary classification problem where you should predict if a given candidate is relevant or not to the text query.

You are free to use any external data sources or libraries, but please provide references if they make up a major part of your solution.

You are allowed to use heuristics, non-ML related techniques, or even exploit leakage (if any). However, hand-labeling is not allowed.

Progress

1 Initial Observations

  1. Certain columns are heavily skewed, so we plan to exclude them from both training and serving. These columns are: crossorigin, height, ismap, loading, longdesc, referrerpolicy, sizes, srcset, width, class, and style.
  2. The queries have been concatenated.
  3. The columns titled 'title', 'alt', and 'text' contain textual information.
  4. We will frame this task as a binary classification problem. The objective is to predict, based on the queries and additional features, whether an image is relevant to the provided query.
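The column exclusion in item 1 can be sketched in pandas as follows. This is a minimal sketch, not the original code: `dominant_value_share` is an assumed way of flagging a skewed column (the share of rows taken by its single most frequent value, counting missing values), and the 0.95 threshold in `find_skewed_columns` is a hypothetical default.

```python
import pandas as pd

# Columns listed above as skewed and excluded from training and serving
SKEWED_COLUMNS = [
    "crossorigin", "height", "ismap", "loading", "longdesc", "referrerpolicy",
    "sizes", "srcset", "width", "class", "style",
]

def dominant_value_share(s: pd.Series) -> float:
    """Fraction of rows taken by the most frequent value (missing values count as a value)."""
    if len(s) == 0:
        return 0.0
    # value_counts sorts by frequency, so the first entry is the dominant value
    return float(s.value_counts(dropna=False, normalize=True).iloc[0])

def find_skewed_columns(df: pd.DataFrame, threshold: float = 0.95) -> list:
    """Hypothetical check: flag columns in which a single value dominates."""
    return [c for c in df.columns if dominant_value_share(df[c]) >= threshold]

def drop_skewed_columns(df: pd.DataFrame) -> pd.DataFrame:
    # errors="ignore" skips listed columns that are absent from df
    return df.drop(columns=SKEWED_COLUMNS, errors="ignore")
```

Applying `drop_skewed_columns` in a shared preprocessing step keeps training and serving consistent, since both see exactly the same columns.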

2 Insights from the Baseline Model

  1. We removed the skewed columns.
  2. Using heuristics, we engineered a few features, listed in the table below.
  3. Having framed the task as binary classification, we chose XGBoost to evaluate the heuristic features because of its simplicity and efficiency.
  4. Training XGBoost on these features gave us a baseline model with an AUC of 0.642.
| Engineered feature | Notes | Type |
|---|---|---|
| in_title | whether the query is in the title | boolean |
| in_alt | whether the query is in the alt | boolean |
| in_text | whether the query is in the text | boolean |
| same_source | whether the image is from the same site | boolean |
| text_tree_len | the length of the text tree | int |
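A minimal sketch of the three `in_*` features, assuming a pandas DataFrame with a `query` column (a hypothetical name) alongside `title`, `alt`, and `text`, and assuming that a case-insensitive substring match counts as the query being "in" a field. `same_source` and `text_tree_len` are left out because the columns they are derived from are not described here.

```python
import pandas as pd

def add_heuristic_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add in_title, in_alt, in_text: does the lowercased query occur in each field?"""
    out = df.copy()
    queries = out["query"].fillna("").str.lower()  # "query" is a hypothetical column name
    for col in ("title", "alt", "text"):
        field = out[col].fillna("").str.lower()
        # Row-wise substring test; an empty query never counts as a match
        out[f"in_{col}"] = [bool(q) and q in f for q, f in zip(queries, field)]
    return out
```

Exact substring matching is deliberately crude; it only fires when the whole query string appears verbatim, which is one reason such features make a weak baseline.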

3 Observations from embeddings

  1. Next, we aimed to make further use of the textual columns, for example through the similarity between the query and the "title" and "alt" columns, to see whether the AUC improves.
  2. We produced intermediate features by embedding the queries, titles, and alts. We then computed the cosine similarity between the query and title embeddings, and between the query and alt embeddings, and used these similarities as classification features (see the table below).
  3. Initially, we used BERT embeddings (bert-base-nli-mean-tokens), which gave an AUC of 0.592, lower than the heuristic baseline of 0.642.
| Embedding | Features | Training sample size | AUC | nfold | Boost rounds |
|---|---|---|---|---|---|
| bert-base-nli-mean-tokens | similarities (alt, title) + heuristics | 7,500 | 0.592 | 5 | 1000 |
| Engineered feature | Notes | Type |
|---|---|---|
| query_embed | the query converted into an embedding tensor | tensor (not used in classification) |
| title_embed | the title converted into an embedding tensor | tensor (not used in classification) |
| alt_embed | the alt converted into an embedding tensor | tensor (not used in classification) |
| sim_q_title | cosine similarity between the query and title embeddings | float |
| sim_q_alt | cosine similarity between the query and alt embeddings | float |
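Once the embeddings are stacked into matrices, the similarity features can be computed in one vectorized step. This sketch assumes row-aligned NumPy arrays of shape (n, d), one row per query/candidate pair; the `eps` guard against zero-length vectors is an added assumption, not something the original describes.

```python
import numpy as np

def rowwise_cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Cosine similarity between row i of a and row i of b, for every i."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    dots = np.einsum("ij,ij->i", a, b)  # row-wise dot products
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dots / np.maximum(norms, eps)  # zero vectors yield 0 instead of NaN
```

Assuming each DataFrame cell holds one embedding vector, a feature such as `sim_q_title` could then be filled with `rowwise_cosine(np.vstack(df["query_embed"].to_list()), np.vstack(df["title_embed"].to_list()))`. Rows that hold `None` for a missing value would need to be filled first.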

4 Observations from more embeddings

  1. We explored various embeddings to test the efficacy of this approach.
  2. By utilizing pretrained models from the sentence_transformers package, we observed a notable improvement in performance.
  3. All subsequent tests used 5-fold cross-validation and 1,000 boosting rounds.
  4. Switching from BERT to all-MiniLM-L6-v2, and then to all-mpnet-base-v2, raised the AUC from 0.592 to 0.708 and then to 0.723 (7,500 training samples in each case).
  5. Increasing the sample size also raised the AUC: with all-MiniLM-L6-v2, the AUC went from 0.708 to 0.751 to 0.765 as the sample size grew from 10,000 to 100,000 and finally to 600,000 (7,500, 75,000, and 450,000 training samples in the table).
  6. The best performance, an AUC of 0.769, was achieved with all-mpnet-base-v2 and 600,000 samples.
  7. Dropping the heuristics and relying solely on the embedding features increased the AUC from 0.708 to 0.733 with all-MiniLM-L6-v2, but slightly decreased it from 0.769 to 0.762 with all-mpnet-base-v2.
| Embedding | Features | Training sample size | AUC | nfold | Boost rounds |
|---|---|---|---|---|---|
| bert-base-nli-mean-tokens | similarities (alt, title) + heuristics | 7,500 | 0.592 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 7,500 | 0.708 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 75,000 | 0.751 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title) + heuristics | 450,000 | 0.765 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title) | 7,500 | 0.733 | 5 | 1000 |
| all-mpnet-base-v2 | similarities (alt, title) | 450,000 | 0.726 | 5 | 1000 |
| all-mpnet-base-v2 | similarities (alt, title) + heuristics | 7,500 | 0.723 | 5 | 1000 |
| all-mpnet-base-v2 | similarities (alt, title) + heuristics | 450,000 | 0.769 | 5 | 1000 |

5 Performance improvement

  1. In our earlier experiments, generating embeddings with sentence-transformers was a bottleneck, mainly because we used pandas.Series.apply() to encode the texts one row at a time.
  2. After coming across a helpful Reddit post, we rewrote the code to use vectorization: each whole column is passed to the encoder as a single list, so the model encodes the texts in batches instead of making one call per row. The speedup was substantial: at 10,000 samples on a GPU, the total time fell from 7 min 28 s to 26.9 s.
| Sample size | Device | User time | System time | Total time (user + system) | Method |
|---|---|---|---|---|---|
| 1,000 | CPU | 27.2 s | 12.9 s | 40.1 s | apply() |
| 1,000 | GPU | 45.4 s | 2.86 s | 48.3 s | apply() |
| 10,000 | CPU | 4 min 31 s | 2 min 3 s | 6 min 35 s | apply() |
| 10,000 | GPU | 7 min 5 s | 23.8 s | 7 min 28 s | apply() |
| 10,000 | GPU | 22.2 s | 4.72 s | 26.9 s | vectorization |
| 100,000 | GPU | 3 min 5 s | 16.2 s | 3 min 21 s | vectorization |
| 600,000 | GPU | 21 min 25 s | 11 min 3 s | 32 min 28 s | vectorization |
| 600,000 | GPU | 45 min 14 s | 25 min 42 s | 1 h 10 min 56 s | vectorization |

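```python
# Embedding code with pandas.Series.apply(): one encode() call per row
import pandas as pd
from sentence_transformers import SentenceTransformer

# MODEL_EMBED is assumed to be a SentenceTransformer loaded once, for example:
MODEL_EMBED = SentenceTransformer("all-mpnet-base-v2")

def embed_values(value):
    """Encode a single text value; missing values map to None."""
    if pd.notna(value):
        return MODEL_EMBED.encode(value)
    return None

def add_embedding(df: pd.DataFrame, new_column: str, column_to_apply: str) -> pd.DataFrame:
    df[new_column] = df[column_to_apply].apply(embed_values)
    return df
```

```python
# Embedding code with pandas vectorization: one encode() call for the whole
# column, so the model can process the texts in batches
def embed_speedup(df: pd.DataFrame, col: str, embedded_col: str) -> pd.DataFrame:
    # Missing values are encoded as empty strings (the apply() version stores None);
    # the source column itself is left unchanged
    texts = df[col].fillna("").tolist()
    embedded = MODEL_EMBED.encode(texts)  # NumPy array, one row per text
    df[embedded_col] = embedded.tolist()
    return df
```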

6 Adding more embeddings

  1. On revisiting the work of the past few days, we realized that we had not yet embedded the text column. After embedding it and computing the cosine similarity between the query and text embeddings, the AUC improved further.
  2. The highest AUC, 0.790, was achieved with all-mpnet-base-v2 embeddings and 450,000 training samples.
| Embedding | Features | Training sample size | AUC | nfold | Boost rounds |
|---|---|---|---|---|---|
| all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 7,500 | 0.729 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 75,000 | 0.770 | 5 | 1000 |
| all-MiniLM-L6-v2 | similarities (alt, title, text) + heuristics | 450,000 | 0.788 | 5 | 1000 |
| all-mpnet-base-v2 | similarities (alt, title, text) + heuristics | 450,000 | 0.790 | 5 | 1000 |
| Engineered feature | Notes | Type |
|---|---|---|
| query_embed | the query converted into an embedding tensor | tensor (not used in classification) |
| title_embed | the title converted into an embedding tensor | tensor (not used in classification) |
| alt_embed | the alt converted into an embedding tensor | tensor (not used in classification) |
| text_embed | the text converted into an embedding tensor | tensor (not used in classification) |
| sim_q_title | cosine similarity between the query and title embeddings | float |
| sim_q_alt | cosine similarity between the query and alt embeddings | float |
| sim_q_text | cosine similarity between the query and text embeddings | float |

7 Running on the test set

  1. When we ran the classifier on the test set, the AUC dropped sharply, from 0.790 in cross-validation to 0.501, which is essentially chance level.
  2. This large gap most likely points to overfitting. In future work, we would investigate how to overcome it.
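For context, 0.5 is the AUC of a classifier whose scores rank positives and negatives at random, so a test AUC of 0.501 means the scores carry essentially no ranking information on the test set. The sketch below is illustrative only and not part of the original pipeline: it computes AUC from ranks (the Mann-Whitney U formulation, with tied scores sharing their average rank).

```python
import numpy as np

def auc_from_ranks(y_true, scores) -> float:
    """ROC AUC = P(random positive scores above random negative), ties count 1/2."""
    y = np.asarray(y_true).astype(bool)
    s = np.asarray(scores, dtype=float)
    n_pos, n_neg = int(y.sum()), int((~y).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUC needs at least one positive and one negative")
    # 1-based ranks; tied scores share the average of the ranks they span
    _, inverse, counts = np.unique(s, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)
    avg_rank = last_rank - (counts - 1) / 2.0
    ranks = avg_rank[inverse]
    # Mann-Whitney U statistic of the positives, normalized to [0, 1]
    u = ranks[y].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))

print(auc_from_ranks([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Feeding it uniformly random scores gives a value close to 0.5, which is where the test-set result landed.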

Final Reflections

  1. Due to time constraints, we decided not to crawl the images. Instead, we focused on engineering useful features for the classifier from the data already available.
  2. We framed this challenge as a binary classification problem, opting for XGBoost because of its simplicity and efficiency.
  3. We used heuristics to lay the groundwork. Although the result fell short of expectations, this quick step established a baseline. Combined with the embedding features, the heuristics gave a small further improvement.
  4. Because certain queries were concatenated, we used embeddings to capture their underlying meaning. We then measured the similarity between the query embedding and the text, alt, and title embeddings. This approach proved effective.
  5. Experimenting with different embeddings improved the final AUC. Our starting point, BERT embeddings, actually scored below the heuristic baseline (0.592 vs. 0.642).
  6. Exploring the sentence-transformers library, we found stronger embedding models, namely all-MiniLM-L6-v2 and all-mpnet-base-v2.
  7. We had to balance embedding quality against speed. At the outset, embedding speed was a concern, so we used the smaller and faster all-MiniLM-L6-v2 for most experiments. This model markedly improved the classification results.
  8. Later, vectorized (batched) encoding with pandas on a GPU sped up the embedding step considerably, which allowed us to adopt all-mpnet-base-v2 and push the classification results further.
  9. Running the trained classifier on the test set revealed severe overfitting. Given more time, we would investigate this further. Likely causes include:
    1. A highly imbalanced dataset.
    2. Concatenated queries that have lost their original meaning.
  10. With more time to address the overfitting, we would:
    1. Use data augmentation to generate more positive samples and make the dataset more balanced.
    2. Parse and split concatenated queries to recover their original meaning.
    3. Crawl the images and use multimodal models to extract image features.
  11. Because the same data preprocessing code is needed in both the training and prediction phases, we moved it into a separate module: data_processing.py.
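The plan to parse concatenated queries could start from a simple dictionary-based word segmentation. The sketch below assumes that "concatenated" means words run together without spaces (as in the hypothetical query "redrunningshoes") and that a vocabulary of known words is available; `segment_query` is a hypothetical helper that picks the split with the fewest words via dynamic programming and returns None when no complete split exists. A production version would more likely score candidate splits with word frequencies.

```python
from __future__ import annotations

def segment_query(query: str, vocab: set[str], max_word_len: int = 20) -> list[str] | None:
    """Split a run-together query into vocabulary words, preferring the fewest words."""
    query = query.lower()
    n = len(query)
    # best[i] is the best segmentation of query[:i], or None if there is none
    best: list[list[str] | None] = [None] * (n + 1)
    best[0] = []
    for i in range(1, n + 1):
        for j in range(max(0, i - max_word_len), i):
            prefix = best[j]
            word = query[j:i]
            if prefix is not None and word in vocab:
                candidate = prefix + [word]
                if best[i] is None or len(candidate) < len(best[i]):
                    best[i] = candidate
    return best[n]

print(segment_query("redrunningshoes", {"red", "run", "ning", "running", "shoes"}))
# ['red', 'running', 'shoes']
```

Preferring the fewest words favors "running" over "run" + "ning"; the segmented query could then be embedded in place of the run-together string.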