Fine-tuning Gemma, the journey from beginning to end
Chatbots are one of the more common early use cases for generative AI, particularly in retail organizations. To be useful for shoppers, a chatbot needs to be contextually sensitive to a retailer’s product catalog and able to respond to customer inquiries conversationally, perhaps suggesting alternatives and complementary outfits. An instruction-tuned base model can help with the conversation, but fine-tuning is required both to incorporate the retailer’s product catalog data and to respond in the instructed format. This blog series describes the process for fine-tuning a Gemma instruction-tuned model, using the following approach:
Preparing the dataset
Fine-tuning the instruction-tuned model
Using hyperparameter tuning to generate multiple iterations of the model
Validating, evaluating and identifying the most ideal model
In this first blog, we walk through the process and share methodologies, frameworks, tools and lessons learned for fine-tuning a base model to help you on your journey. The examples discussed are available in the following repository.
Preparing the dataset
The dataset we are going to use is a pre-crawled subset of a larger dataset (consisting of over 5.8 million products) that was created by extracting data from Flipkart, a prominent Indian e-commerce store. Open-source dataset: Flipkart Product Catalog.
Dataset Preparation Methodology:
The dataset was filtered to concentrate exclusively on clothing items for women, men, and children. Products in a category with fewer than ten units were excluded.
To fine-tune an instruction-tuned model, each data point must consist of three distinct components: question, answer, and context.
Gemini 1.5 Flash on Vertex AI was used to generate natural, shopping-style conversational questions (inputs) and answers (outputs). The inputs and outputs were formulated from the product specifications detailed in the catalog.
The prompts used in fine-tuning were created by combining the questions and answers generated by Gemini 1.5 Flash. The product category is used to set the context of the conversation, as shown in the prompt sample:
<start_of_turn>user
Context:Online shopping for Men's Clothing
What kind of material is the SKOOKIE Sleeveless Solid Men's Jacket made of?<end_of_turn>
<start_of_turn>model
The SKOOKIE Sleeveless Solid Men's Jacket is made of a blinded fabric.
Product Name: SKOOKIE Sleeveless Solid Men's Jacket
Product Category: Men's Clothing
Product Details:
- Sleeve: Sleeveless
- Reversible: No
- Fabric: Blinded
- Pattern: Solid
- Ideal For: Men's
- Style Code: SKMJ-3005-C-Blue
<end_of_turn>
The end-of-sequence (EOS) token was appended to each prompt:
EOS_TOKEN = '<eos>'
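The filtering and prompt-assembly steps above can be sketched in plain Python. This is a minimal illustration rather than the code from the repository: the function names, the `category` field, and the exact whitespace of the template are assumptions based on the prompt sample.

```python
from collections import Counter

EOS_TOKEN = "<eos>"


def filter_small_categories(products, min_count=10):
    """Keep only products whose category has at least `min_count` entries.

    `products` is a list of dicts with a "category" key (assumed field name).
    """
    counts = Counter(p["category"] for p in products)
    return [p for p in products if counts[p["category"]] >= min_count]


def build_prompt(context, question, answer):
    """Assemble one training prompt in Gemma's chat-turn format.

    The product category sets the context; the EOS token is appended last.
    """
    return (
        f"<start_of_turn>user\nContext:{context}\n{question}<end_of_turn>\n"
        f"<start_of_turn>model\n{answer}<end_of_turn>{EOS_TOKEN}"
    )


if __name__ == "__main__":
    # Hypothetical catalog: 12 jackets and 3 scarves.
    catalog = [{"category": "Jackets"}] * 12 + [{"category": "Scarves"}] * 3
    print(len(filter_small_categories(catalog)))
    print(build_prompt(
        "Online shopping for Men's Clothing",
        "What kind of material is the jacket made of?",
        "The jacket is made of a blended fabric.",
    ))
```

Keeping the template in one function makes it easy to apply the same format at training time and at inference time.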
Fine-tuning
To fine-tune Gemma for a chat-style experience with PyTorch, we used Low-Rank Adaptation (LoRA), a Parameter-Efficient Fine-Tuning (PEFT) technique, together with Supervised Fine-Tuning (SFT) from the Transformer Reinforcement Learning (TRL) library. We experimented with both pre-trained and instruction-tuned base models, ultimately finding that an instruction-tuned model, paired with its tailored dataset, best aligned with our desired conversational outcomes.
Early trials used quantized int4 models on T4 GPUs in Google Colab notebooks and on L4 GPUs on Google Kubernetes Engine (GKE). These runs yielded promising results on single GPUs but were limited by the 7b model’s large GPU memory footprint (40GB+). Even so, the quantization approach enabled quick, cost-effective initial experiments that guided our fine-tuning strategy.
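To see why quantization matters on a single GPU, it helps to estimate the memory taken by the weights alone. The sketch below is back-of-the-envelope arithmetic under stated assumptions: it uses a nominal 7 billion parameters, and it ignores gradients, optimizer state, and activations, all of which add to the footprint during training.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for the model weights alone, in GB (10^9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


if __name__ == "__main__":
    nominal_params = 7e9  # assumed nominal size for a "7b" model
    for label, bits in [("int4", 4), ("bf16", 16), ("fp32", 32)]:
        print(f"{label}: {weight_memory_gb(nominal_params, bits):.1f} GB")
```

For reference, a T4 has 16 GB of memory and an L4 has 24 GB, so only the quantized weights leave meaningful headroom for training on either card.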
Scaling up with multi-GPU and multi-host environments
For large-scale fine-tuning, we opted for non-quantized models in a multi-GPU, multi-host configuration. We used Accelerate with Flash Attention-2 for single and multi-GPU orchestration and Fully Sharded Data Parallelism (FSDP) for model parameter sharding across GPUs. This configuration, while powerful, required extensive testing and iteration to optimize.
Accelerate is sensitive to its configuration settings, and a wrong value can fail in non-obvious ways. Its debug mode is intended only as a wrapper to surface PyTorch code failures, so it did not report errors about our specific invalid configurations. When the fine-tuning process was started with `accelerate launch` from a Kubernetes Job deployed on GKE, the errors surfaced at trainer.train(). These failures can go unnoticed: when the job did not fail quickly, we ran into the default 60-minute timeout.
Crucially, we discovered the importance of accurately specifying num_machines (number of physical machines) and num_processes (total GPUs used) to prevent trainer.train() timeouts. GKE provisioned the desired resources through resource requests, and we mapped those resources to the Accelerate configuration.
While num_machines is straightforward, num_processes is the total number of GPUs used by the job across all machines (e.g., 2 machines with 4 GPUs each means num_machines=2 and num_processes=8).
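The mapping from provisioned resources to launch settings can be derived in code rather than typed by hand, which removes one source of mismatch. The helper below is a hypothetical sketch: the function name, host name, port, and script name are illustrative, while the flags themselves are options of `accelerate launch`.

```python
def accelerate_launch_args(num_machines, gpus_per_machine, machine_rank,
                           main_process_ip, main_process_port, script):
    """Build an `accelerate launch` command line for one machine of the job.

    num_processes is the total GPU count across all machines,
    not the number of GPUs on a single machine.
    """
    num_processes = num_machines * gpus_per_machine
    return [
        "accelerate", "launch",
        "--num_machines", str(num_machines),
        "--num_processes", str(num_processes),
        "--machine_rank", str(machine_rank),
        "--main_process_ip", main_process_ip,
        "--main_process_port", str(main_process_port),
        script,
    ]


if __name__ == "__main__":
    # 2 machines with 4 GPUs each -> num_machines=2, num_processes=8.
    args = accelerate_launch_args(
        num_machines=2,
        gpus_per_machine=4,
        machine_rank=0,
        main_process_ip="finetune-0.finetune",  # hypothetical host name
        main_process_port=29500,                # hypothetical port
        script="fine_tune.py",                  # hypothetical script name
    )
    print(" ".join(args))
```

On GKE, the per-machine GPU count would come from the same value used in the Job's GPU resource request, so the two cannot drift apart.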
Overcoming dataset and configuration challenges
Initially, Gemma frequently responded with repetitive <pad> tokens, indicating a failure to generate meaningful output. To address this, we set the tokenizer’s padding side to right, set the pad token to the EOS token, and appended the EOS token to each row in the dataset. We also refined the dataset: fine-tuning on individual, fully formed prompts, rather than on context-input pairs, was the key. The prompt processing was implemented with the dataset map function in batched mode.
The combination of non-quantized + Accelerate + FSDP on Gemma required specific settings for the configuration parameters used, which involved some trial and error. We had to run accelerate config and respond to the wizard prompts, which consisted of questions (e.g., distribution type, FSDP, and precision level), to achieve the desired results. Stepping through the wizard let us determine the settings we needed, after which the model no longer produced <pad> results.
The last part of our fine-tuning involved saving our fine-tuned model weights and the tokenizer. Since Accelerate distributes the work across the GPUs on a single host and across other worker hosts, we rely on is_main_process, which ensures that only the main (leader) process saves the aggregated model weights and tokenizer.
The dataset’s structure and contents are critical, and the right choices depend on your model and expectations. As noted earlier, we had to adjust our dataset to fine-tune the model to achieve our desired response format.
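The padding settings and the batched prompt processing described above can be sketched as follows. This is an illustration, not the repository code: the column names `prompt` and `text` are assumptions, and a stand-in object replaces the real tokenizer so the snippet runs without downloading a model.

```python
from types import SimpleNamespace


def configure_padding(tokenizer):
    """Pad on the right and reuse the EOS token as the pad token."""
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def append_eos(batch, eos_token="<eos>"):
    """Batched map function: `batch` maps column names to lists of values.

    Returns a new "text" column where every prompt ends with the EOS token.
    """
    return {
        "text": [
            p if p.endswith(eos_token) else p + eos_token
            for p in batch["prompt"]
        ]
    }


if __name__ == "__main__":
    # Stand-in for a Hugging Face tokenizer (only the attributes used here).
    tok = configure_padding(
        SimpleNamespace(eos_token="<eos>", pad_token="<pad>", padding_side="left")
    )
    print(tok.padding_side, tok.pad_token)

    batch = {"prompt": ["<start_of_turn>user\nHi<end_of_turn>", "done<eos>"]}
    print(append_eos(batch)["text"])
    # With the Hugging Face datasets library, the same function would be
    # applied as: dataset = dataset.map(append_eos, batched=True)
```

Processing in batches means the function receives lists of column values, which is why it returns a list under the new column name.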
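A save step guarded by is_main_process might look like the sketch below. It follows the pattern documented for Accelerate with sharded models and assumes a Transformers-style model whose save_pretrained accepts `is_main_process`, `save_function`, and `state_dict`; the function name and argument order are illustrative, and the repository may do this differently.

```python
def save_final_artifacts(accelerator, model, tokenizer, output_dir):
    """Save model weights and tokenizer once, from the main process only."""
    # Make sure every process has finished training before saving.
    accelerator.wait_for_everyone()

    # Gathering the full state dict is a collective step under FSDP, so it
    # runs on every process, not only on the main one.
    state_dict = accelerator.get_state_dict(model)

    unwrapped = accelerator.unwrap_model(model)
    unwrapped.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,  # only the leader writes
        save_function=accelerator.save,
        state_dict=state_dict,
    )

    # The tokenizer is identical on all processes; write it once.
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
```

The key design point is that the gather happens everywhere while the write happens once; guarding the gather itself behind is_main_process can leave the other processes waiting.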
Hyperparameter tuning
Fine-tuning large language models like Gemma also involves carefully adjusting various hyperparameters to optimize performance and tailor the model’s behavior for specific tasks.
Not all parameters matter equally, and the best values vary by model. However, let’s walk through some patterns that can help identify a fine-tuned model that is sufficient for the use case.
In our fine-tuning process, we focused on the following key parameters:
Core LoRA Parameters
lora_r (Rank): This parameter controls the rank of the low-rank matrices used in LoRA. A higher rank increases the expressiveness of the model but also increases the number of trainable parameters.
lora_alpha (Alpha): This scaling factor balances the contribution of the LoRA updates to the original model weights. A larger alpha can lead to faster adaptation but may also introduce instability.
lora_dropout (Dropout): Dropout is a regularization technique that helps prevent overfitting. Applying dropout to LoRA layers can further improve generalization.
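The trade-off between rank and trainable parameters can be made concrete with a little arithmetic. The sketch below assumes the standard LoRA formulation, in which a frozen weight matrix of shape d_out x d_in gets two low-rank factors and the update is scaled by lora_alpha / r. The 3072 x 3072 layer size is an illustrative choice, not a figure from this post.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update in the standard formulation."""
    return lora_alpha / r


if __name__ == "__main__":
    d_in = d_out = 3072  # illustrative square projection layer
    full = d_in * d_out
    # (rank, alpha) pairs matching the experiments in this post.
    for r, alpha in [(8, 16), (16, 32), (32, 64)]:
        added = lora_trainable_params(d_in, d_out, r)
        print(f"r={r:<2} alpha={alpha:<2} trainable={added} "
              f"({added / full:.2%} of the full matrix) "
              f"scaling={lora_scaling(alpha, r)}")
```

Doubling the rank doubles the trainable parameters per adapted layer, while raising alpha in step with the rank keeps the scaling factor constant.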
Learning rate, optimization, and training duration
learning_rate (Learning Rate): The learning rate determines the step size the optimizer takes during training. Finding the optimal learning rate is crucial for efficient convergence and avoiding overshooting.
epochs (Number of Epochs): The number of epochs dictates how many times the model sees the entire training dataset. More epochs can improve the fit to the training data, but the risk of overfitting grows with each additional pass.
max_grad_norm (Maximum Gradient Norm): Gradient clipping helps prevent exploding gradients, which can destabilize training. This parameter limits the maximum norm of the gradients.
weight_decay (Weight Decay): Weight decay is a regularization technique that adds a penalty to the loss function based on the magnitude of the model weights. This helps prevent overfitting.
warmup_ratio (Warmup Ratio): This parameter determines the proportion of the total training steps during which the learning rate is gradually increased from zero to its configured value. Warmup can improve stability and convergence.
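Two of these parameters are easy to illustrate numerically: warmup_ratio and max_grad_norm. The sketch below is a simplified model of what a training loop does, not the trainer's actual code. It assumes the warmup step count is rounded up and that the learning rate decays linearly to zero after warmup; the decay shape depends on the scheduler you choose.

```python
import math


def lr_at_step(step, total_steps, peak_lr, warmup_ratio):
    """Linear warmup from zero to peak_lr, then linear decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / remaining)


def clip_scale(grad_norm, max_grad_norm):
    """Factor by which gradients are multiplied so their norm stays bounded."""
    if grad_norm <= 0:
        return 1.0
    return min(1.0, max_grad_norm / grad_norm)


if __name__ == "__main__":
    # Illustrative values: 1000 steps, peak rate 2e-4, 10% warmup.
    for step in (0, 50, 100, 550, 1000):
        print(step, lr_at_step(step, 1000, 2e-4, 0.1))

    # A gradient norm of 4.0 with max_grad_norm=1.0 is scaled down;
    # a norm already under the limit is left unchanged.
    print(clip_scale(4.0, 1.0), clip_scale(0.5, 1.0))
```

A larger warmup_ratio spends more of the run below the configured learning rate, which matters most when the number of epochs is small.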
Sequence length considerations
max_seq_length (Maximum Sequence Length): This parameter controls the maximum length, in tokens, of the input sequences used during training. Longer sequences can provide more context but also require more computational resources.
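One practical way to pick a sequence length is to check how many training prompts each candidate value would cut short. The helper below is a small sketch; the token lengths are made-up numbers, and in practice they would come from tokenizing the prepared prompts.

```python
def truncated_fraction(token_lengths, max_seq_length):
    """Fraction of examples longer than max_seq_length (and so truncated)."""
    if not token_lengths:
        return 0.0
    over = sum(1 for n in token_lengths if n > max_seq_length)
    return over / len(token_lengths)


if __name__ == "__main__":
    lengths = [120, 340, 510, 700, 1500]  # hypothetical per-prompt token counts
    for candidate in (512, 1024, 2048):
        print(candidate, truncated_fraction(lengths, candidate))
```

If a shorter length truncates almost nothing, the extra memory and compute of a longer one buys little for that dataset.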
The following table displays the experimental values we used for the aforementioned parameters:
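| parameter | base | job-0 | job-1 | job-2 |
| --- | --- | --- | --- | --- |
| epochs | 1 | 2 | 3 | 4 |
| lora_r | 8 | 8 | 16 | 32 |
| lora_alpha | 16 | 16 | 32 | 64 |
| lora_dropout | 0.1 | 0.1 | 0.2 | 0.3 |
| max_grad_norm | 0.3 | 1 | 1 | 1 |
| learning_rate | 2.00E-04 | 2.00E-05 | 2.00E-04 | 3.00E-04 |
| weight_decay | 0.001 | 0.01 | 0.005 | 0.001 |
| warmup_ratio | 0.03 | 0.1 | 0.2 | 0.3 |
| max_seq_length | 512 | 1024 | 2048 | 8192 |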
These parameters play a critical role in determining the accuracy of the model, and there is no one-size-fits-all setting, because both the dataset and the model influence the outcome.
Wrap-up
This has been a fun journey filled with a variety of goals and tasks to fine-tune our Gemma instruction-tuned model. This exploration reinforced the importance of data preparation and hyperparameter optimization in the fine-tuning process. With our dataset, the choice of fine-tuning parameters had a significant effect on the model’s responses, which ranged from incoherent to well structured. It is important to ensure the chatbot understands nuanced language, handles complex dialogues, and delivers accurate responses.
Stay tuned to this fine-tuning series as we will cover the observability of our jobs through MLflow, which is essential for maximizing model effectiveness.