GCP – A step-by-step guide to fine-tuning MedGemma for breast tumor classification
Disclaimer: This guide is for informational and educational purposes only and is not a substitute for professional medical advice, diagnosis, or treatment.
Artificial intelligence (AI) is revolutionizing healthcare, but how do you take a powerful, general-purpose AI model and teach it the specialized skills of a pathologist? This journey from prototype to production often begins in a notebook, which is exactly where we’ll start.
In this guide, we'll take the crucial first step. We'll walk through the complete process of fine-tuning MedGemma, Google's family of open models for the medical community built on Gemma 3, to classify breast cancer histopathology images. We're using the full-precision (unquantized) MedGemma model because that's what you'll need to get maximum performance on many clinical tasks. If you're concerned about compute costs, you can instead quantize the model and fine-tune it by using MedGemma's pre-configured fine-tuning notebook.
To complete our first step, we’ll use the Finetune Notebook. The notebook provides you with all of the code and a step-by-step explanation of the process, so it’s the perfect environment for experimentation. I’ll also share the key insights that I learned along the way, including a critical choice in data types that made all the difference.
After we’ve perfected our model in this prototyping phase, we’ll be ready for the next step. In an upcoming post, we’ll show you how to take this exact workflow and move it to a scalable, production-ready environment using Cloud Run jobs.
Setting the stage: Our goal, model, and data
Before we get to the code, let’s set the stage. Our goal is to classify microscope images of breast tissue into one of eight categories: four benign (non-cancerous) and four malignant (cancerous). This type of classification represents one of many crucial tasks that pathologists perform in order to make an accurate diagnosis, and we have a great set of tools for the job.
We’ll be using MedGemma, a powerful family of open models from Google that’s built on the same research and technology that powers our Gemini models. What makes MedGemma special is that it isn’t just a general model: it’s been specifically tuned for the medical domain.
The MedGemma vision component, MedSigLIP, was pre-trained on a vast amount of de-identified medical imagery, including the exact type of histopathology slides that we're using. If you don't need MedGemma's text-generation capabilities, you can use MedSigLIP on its own as a more cost-effective option for predictive tasks like image classification. There are multiple MedSigLIP tutorial notebooks that you can use for fine-tuning.
The MedGemma language component was also trained on a diverse set of medical texts, making the google/medgemma-4b-it version that we’re using perfect for following our text-based prompts. Google provides MedGemma as a strong foundation, but it requires fine-tuning for specific use cases—which is exactly what we’re about to do.
To train our model, we'll use the Breast Cancer Histopathological Image Classification (BreakHis) dataset. BreakHis is a public collection of thousands of microscope images of breast tumor tissue collected from 82 patients at different magnification factors (40X, 100X, 200X, and 400X). The dataset is publicly available for non-commercial research, and it's described in the paper by F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heutte, "A dataset for breast cancer histopathological image classification."[1]
Handling a 4-billion-parameter model requires a capable GPU, so I used an NVIDIA A100 with 40 GB of VRAM on Vertex AI Workbench. This GPU has the necessary memory, and its NVIDIA Tensor Cores accelerate modern data formats such as bfloat16, which we'll leverage for faster training. In an upcoming post, we'll explain how to calculate the VRAM that's required for your fine-tuning.
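Until that post is out, a rough back-of-envelope check is easy: the weights alone take roughly (number of parameters) × (bytes per parameter). The sketch below is an assumption-laden estimate, not the method from the upcoming post. It uses a round 4 billion parameters and ignores activations, LoRA gradients and optimizer state, and CUDA overhead, which is why a 40 GB card leaves useful headroom. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory for the model weights alone, in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 4e9  # round figure for a "4B" model (assumption)

# float32 stores 4 bytes per value; bfloat16 and float16 store 2
for name, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{name}: ~{weight_memory_gb(NUM_PARAMS, nbytes):.0f} GB")
# float32: ~16 GB
# bfloat16: ~8 GB
```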
My float16 disaster: A crucial lesson in stability
My first attempt to load the model used the common float16 data type to save memory. It failed spectacularly. The model’s outputs were complete garbage, and a quick debugging check revealed that every internal value had collapsed into NaN (Not a Number).
The culprit was a classic numerical overflow.
To understand why, you need to know the critical difference between these 16-bit formats:
- float16 (FP16): Has a narrow numerical range. Its largest finite value is 65,504, and anything larger overflows to infinity. During the millions of calculations in a transformer, intermediate values can easily exceed this limit, and operations on the resulting infinities (such as inf - inf) produce NaN. Once a NaN appears, it contaminates every subsequent calculation.
- bfloat16 (BF16): This format, developed at Google Brain, makes a crucial trade-off. It keeps the same 8-bit exponent, and therefore the same wide numerical range, as 32-bit float32, and pays for it with fewer mantissa bits (a little less precision).
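You can reproduce this failure mode in a few lines of PyTorch on the CPU. The values below are arbitrary, chosen only so that an intermediate product (300 × 300 = 90,000) exceeds float16's 65,504 limit; it's an illustration of the mechanism, not the exact computation that failed in the model.

```python
import torch

# Largest finite value each 16-bit format can represent
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # about 3.39e38, same range as float32

x = torch.tensor([300.0])
fp16 = x.to(torch.float16) * x.to(torch.float16)    # 90,000 overflows to inf
bf16 = x.to(torch.bfloat16) * x.to(torch.bfloat16)  # stays finite, slightly rounded
print(fp16, bf16)

# Once an inf exists, ordinary arithmetic turns it into NaN
nan_result = fp16 - fp16
print(nan_result)
```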
bfloat16's wide range prevents these overflows, which keeps the training process stable. The fix was a simple one-line change, but it rests on this critical concept.
The successful code:
```python
import torch
from transformers import AutoModelForImageTextToText

MODEL_ID = "google/medgemma-4b-it"

# The simple, stable solution
model_kwargs = dict(
    torch_dtype=torch.bfloat16,  # Use bfloat16 for its wide numerical range
    device_map="auto",
    attn_implementation="sdpa",
)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
```
Lesson learned: For fine-tuning large models, prefer bfloat16 over float16 for its stability, as long as your GPU supports it (the A100 does). It's a small change that saves you from a world of NaN-related headaches.
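If you reuse the notebook on other hardware, a small guard can keep you from silently falling back to float16. This is a hypothetical helper (`pick_training_dtype`); falling back to float32 on GPUs without bfloat16 support is our assumption, trading extra memory for stability.

```python
import torch


def pick_training_dtype() -> torch.dtype:
    """Prefer bfloat16 where the GPU supports it; otherwise fall back to float32."""
    # The availability check short-circuits, so this is safe on CPU-only machines
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32


dtype = pick_training_dtype()
print(dtype)
```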
The code walkthrough: A step-by-step guide
Now, let’s get to the code. I’ll break down my Finetune Notebook into clear, logical steps.
Step 1: Setup and installations
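First, install the necessary libraries from the Hugging Face ecosystem and log in to your Hugging Face account so that you can download the model. MedGemma is a gated model, so you also need to accept its terms of use on the model page before downloading it.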
```python
# Install required packages (bitsandbytes provides the paged 8-bit AdamW optimizer used in Step 6)
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn bitsandbytes

import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
```
Hugging Face authentication and the recommended approach to handling your secrets
⚠️ Important security note: You should never hardcode secrets like API keys or tokens directly into your code or notebooks, especially in a production environment. This practice is insecure and it creates a significant security risk.
In Vertex AI Workbench, the most secure, enterprise-grade approach to handling secrets (like your Hugging Face token) is to use Google Cloud's Secret Manager.
If you're just experimenting and don't want to set up Secret Manager yet, you can use the interactive login widget. The widget stores the token on the instance's file system.
```python
# Hugging Face authentication using the interactive login widget
from huggingface_hub import notebook_login
notebook_login()
```
In our upcoming post, where we move this process to Cloud Run jobs, we'll show you the correct and secure way to handle this token by using Secret Manager.
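In the meantime, here's a minimal sketch of what the Secret Manager route can look like. It assumes the `google-cloud-secret-manager` package is installed, the instance's service account can access the secret, and the token is stored in a secret with the hypothetical name `hf-token`; the project ID in the usage line is also a placeholder. Both function names are illustrative, not part of the Finetune Notebook.

```python
def secret_version_name(project_id: str, secret_id: str, version: str = "latest") -> str:
    """Build the Secret Manager resource name for one version of a secret."""
    return f"projects/{project_id}/secrets/{secret_id}/versions/{version}"


def login_to_hf_from_secret_manager(project_id: str, secret_id: str = "hf-token") -> None:
    """Fetch the Hugging Face token from Secret Manager and log in with it."""
    # Imported inside the function so the name builder above works without these packages
    from google.cloud import secretmanager
    from huggingface_hub import login

    client = secretmanager.SecretManagerServiceClient()
    response = client.access_secret_version(
        request={"name": secret_version_name(project_id, secret_id)}
    )
    token = response.payload.data.decode("UTF-8")
    login(token=token)  # the token never appears in the notebook source


# Usage (hypothetical project ID):
# login_to_hf_from_secret_manager("my-gcp-project")
```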
Step 2: Load and prepare the dataset
Next, we download the BreakHis dataset from Kaggle using the kagglehub library. This dataset includes a Folds.csv file, which outlines how the data is split for experiments. The original study used 5-fold cross-validation, but to keep the training time manageable for this demonstration, we’ll focus on Fold 1 and we’ll only use images with 100X magnification. You can explore using other folds and magnifications for more extensive experiments.
```python
!pip install -q kagglehub
import kagglehub
import os
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel

# Download the dataset (images plus the Folds.csv metadata)
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv(os.path.join(path, "Folds.csv"))

# Filter for 100X magnification from the first fold
folds_100x = folds[folds["mag"] == 100]
folds_100x = folds_100x[folds_100x["fold"] == 1]

# Get the train/test splits
folds_100x_test = folds_100x[folds_100x.grp == "test"]
folds_100x_train = folds_100x[folds_100x.grp == "train"]

# Define the base path for images. At the time of writing, on Vertex AI Workbench this was
# /home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/versions/4/BreaKHis_v1
BASE_PATH = os.path.join(path, "BreaKHis_v1")
```
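To explore other folds and magnifications, it can help to wrap the filtering in a small function. This is a hypothetical helper (`select_split`), assuming the `mag`, `fold`, and `grp` columns of Folds.csv used above; it returns the same train/test frames as the 100X, fold 1 code when called with those values.

```python
import pandas as pd


def select_split(folds: pd.DataFrame, mag: int, fold: int):
    """Return (train_df, test_df) for one magnification and fold of Folds.csv."""
    subset = folds[(folds["mag"] == mag) & (folds["fold"] == fold)]
    train_df = subset[subset["grp"] == "train"]
    test_df = subset[subset["grp"] == "test"]
    return train_df, test_df


# Example: reproduce this guide's split, or loop over the others
# folds_100x_train, folds_100x_test = select_split(folds, mag=100, fold=1)
# for fold in range(1, 6):
#     for mag in (40, 100, 200, 400):
#         train_df, test_df = select_split(folds, mag, fold)
```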
Step 2.1: Balance the dataset
The initial train and test splits for the 100X magnification are imbalanced between the benign and malignant classes. To address this, we undersample the majority class in both the training and test sets to create datasets with a 50/50 benign/malignant distribution. Note that this balances only the two top-level classes; the eight subtypes remain unevenly represented.
```python
# --- 1. Create balanced TRAIN set ---
train_benign_df = folds_100x_train[folds_100x_train["filename"].str.contains("benign")]
train_malignant_df = folds_100x_train[folds_100x_train["filename"].str.contains("malignant")]
min_train_count = min(len(train_benign_df), len(train_malignant_df))
balanced_train_benign = train_benign_df.sample(n=min_train_count, random_state=42)
balanced_train_malignant = train_malignant_df.sample(n=min_train_count, random_state=42)
balanced_train_df = pd.concat([balanced_train_benign, balanced_train_malignant])

# --- 2. Create balanced TEST set ---
test_benign_df = folds_100x_test[folds_100x_test["filename"].str.contains("benign")]
test_malignant_df = folds_100x_test[folds_100x_test["filename"].str.contains("malignant")]
min_test_count = min(len(test_benign_df), len(test_malignant_df))
balanced_test_benign = test_benign_df.sample(n=min_test_count, random_state=42)
balanced_test_malignant = test_malignant_df.sample(n=min_test_count, random_state=42)
balanced_test_df = pd.concat([balanced_test_benign, balanced_test_malignant])

# --- 3. Get the final filename lists ---
train_filenames = balanced_train_df["filename"].values
test_filenames = balanced_test_df["filename"].values

print(f"Balanced Train: {len(train_filenames)} files")
print(f"Balanced Test: {len(test_filenames)} files")
```
Step 2.2: Create a Hugging Face dataset
We're converting our data into the Hugging Face datasets format because it's the easiest way to work with the SFTTrainer from Hugging Face's TRL library. The format is optimized for large datasets, especially image datasets, because it loads and decodes examples only when they're needed. It also gives us handy preprocessing tools, like map() for applying our formatting function to every example.
```python
CLASS_NAMES = [
    "benign_adenosis", "benign_fibroadenoma", "benign_phyllodes_tumor",
    "benign_tubular_adenoma", "malignant_ductal_carcinoma",
    "malignant_lobular_carcinoma", "malignant_mucinous_carcinoma",
    "malignant_papillary_carcinoma",
]

# Map a file path to its class index (-1 if no subtype folder matches)
def get_label_from_filename(filename):
    filename = filename.replace("\\", "/").lower()
    if "/adenosis/" in filename: return 0
    if "/fibroadenoma/" in filename: return 1
    if "/phyllodes_tumor/" in filename: return 2
    if "/tubular_adenoma/" in filename: return 3
    if "/ductal_carcinoma/" in filename: return 4
    if "/lobular_carcinoma/" in filename: return 5
    if "/mucinous_carcinoma/" in filename: return 6
    if "/papillary_carcinoma/" in filename: return 7
    return -1

train_data_dict = {
    "image": [os.path.join(BASE_PATH, f) for f in train_filenames],
    "label": [get_label_from_filename(f) for f in train_filenames],
}
test_data_dict = {
    "image": [os.path.join(BASE_PATH, f) for f in test_filenames],
    "label": [get_label_from_filename(f) for f in test_filenames],
}
features = Features({
    "image": HFImage(),
    "label": ClassLabel(names=CLASS_NAMES),
})
train_dataset = Dataset.from_dict(train_data_dict, features=features).cast_column("image", HFImage())
eval_dataset = Dataset.from_dict(test_data_dict, features=features).cast_column("image", HFImage())

print(train_dataset)
print(eval_dataset)
```
Step 3: Prompt engineering
This step is where we tell the model what we want it to do. We create a clear, structured prompt that instructs the model to analyze an image and to return only the number that corresponds to a class. This prompt makes the output simple and easy to parse. We then map this format across our entire dataset.
```python
# Define the instruction prompt
PROMPT = """Analyze this breast tissue histopathology image and classify it.

Classes (0-7):
0: benign_adenosis
1: benign_fibroadenoma
2: benign_phyllodes_tumor
3: benign_tubular_adenoma
4: malignant_ductal_carcinoma
5: malignant_lobular_carcinoma
6: malignant_mucinous_carcinoma
7: malignant_papillary_carcinoma

Answer with only the number (0-7):"""

def format_data(example):
    """Format examples into the chat-style messages MedGemma expects."""
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": PROMPT},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": str(example["label"])},
            ],
        },
    ]
    return example

# Apply formatting
formatted_train = train_dataset.map(format_data, batched=False)
formatted_eval = eval_dataset.map(format_data, batched=False)

print("✓ Data formatted with instruction prompts")
```
Step 4: Load the model and processor
Here, we load the MedGemma model and its associated processor. The processor prepares both the images and the text for the model. We'll also make two key parameter choices for stability and efficiency:
- torch_dtype=torch.bfloat16: As we mentioned earlier, this format ensures numerical stability.
- attn_implementation="sdpa": Scaled dot product attention (SDPA) is PyTorch's highly optimized, fused implementation of attention, available since PyTorch 2.0. Think of it as telling the model to use a super-fast, built-in engine for its most important calculation. It speeds up training and inference, and it can automatically use more advanced backends like FlashAttention if your hardware supports it.
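To see that SDPA computes the same thing as the textbook formula softmax(QKᵀ/√d)·V, only faster, here's a small CPU check with random tensors. The shapes are arbitrary and unrelated to MedGemma's actual attention configuration; on a GPU, PyTorch may dispatch the same call to a FlashAttention or memory-efficient kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Reference: attention written out by hand
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
manual = torch.softmax(scores, dim=-1) @ v

# Fused kernel; PyTorch chooses the backend for the current hardware
fused = F.scaled_dot_product_attention(q, k, v)

matches = torch.allclose(manual, fused, atol=1e-5)
print(matches)
```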
```python
MODEL_ID = "google/medgemma-4b-it"

# Model configuration
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Ensure right padding for training
processor.tokenizer.padding_side = "right"
```
Step 5: Evaluate the baseline model
Before we invest time and compute in fine-tuning, let’s see how the pre-trained model performs on its own. This step gives us a baseline to measure our improvement against.
```python
# Helper functions to run evaluation
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(predictions, references):
    return {
        **accuracy_metric.compute(predictions=predictions, references=references),
        **f1_metric.compute(predictions=predictions, references=references, average="weighted"),
    }

def postprocess_prediction(text):
    """Extract just the number from the model's text output."""
    digit_match = re.search(r"\b([0-7])\b", text.strip())
    return int(digit_match.group(1)) if digit_match else -1

def batch_predict(model, processor, prompts, images, batch_size=8, max_new_tokens=40):
    """A function to run inference in batches."""
    predictions = []
    for i in range(0, len(prompts), batch_size):
        batch_texts = prompts[i:i + batch_size]
        # One list of images per prompt (each prompt contains a single image)
        batch_images = [[img] for img in images[i:i + batch_size]]

        inputs = processor(text=batch_texts, images=batch_images, padding=True, return_tensors="pt").to("cuda", torch.bfloat16)
        # Generated tokens follow the full (padded) input, so slice at its length
        input_length = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=processor.tokenizer.pad_token_id)

        for seq in outputs:
            generated = processor.decode(seq[input_length:], skip_special_tokens=True)
            predictions.append(postprocess_prediction(generated))

    return predictions

# Prepare data for evaluation (user turn only, followed by the generation prompt)
eval_prompts = [processor.apply_chat_template([msg[0]], add_generation_prompt=True, tokenize=False) for msg in formatted_eval["messages"]]
eval_images = formatted_eval["image"]
eval_labels = formatted_eval["label"]

# Run baseline evaluation
print("Running baseline evaluation...")
baseline_preds = batch_predict(model, processor, eval_prompts, eval_images)
baseline_metrics = compute_metrics(baseline_preds, eval_labels)

print(f"\n{'BASELINE RESULTS':-^80}")
print(f"Accuracy: {baseline_metrics['accuracy']:.1%}")
print(f"F1 Score: {baseline_metrics['f1']:.3f}")
print("-" * 80)
```
The performance of the baseline model was evaluated on both 8-class and binary (benign/malignant) classification:
- 8-class accuracy: 32.6%
- 8-class F1 score (weighted): 0.241
- Binary accuracy: 59.6%
- Binary F1 score (malignant): 0.639
These results show that the model performs better than random guessing on the 8-class task (12.5%), but there's significant room for improvement, especially in the fine-grained 8-class classification. The binary accuracy is also only modestly above the 50% chance level.
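The evaluation code computes only 8-class metrics, so here's one plausible way to derive the binary numbers from the same predictions. It's a sketch under stated assumptions, not necessarily how the reported figures were produced: labels 0-3 are benign and 4-7 malignant (matching CLASS_NAMES), and unparseable outputs (-1) count as wrong. `to_binary` and `binary_metrics` are hypothetical helpers.

```python
def to_binary(label: int) -> int:
    """Map an 8-class index to 0 (benign, 0-3) or 1 (malignant, 4-7); anything else becomes -1."""
    if 0 <= label <= 3:
        return 0
    if 4 <= label <= 7:
        return 1
    return -1


def binary_metrics(predictions, references):
    """Accuracy and malignant-class F1 after collapsing the 8 classes into 2."""
    y_pred = [to_binary(p) for p in predictions]
    y_true = [to_binary(r) for r in references]
    pairs = list(zip(y_pred, y_true))
    correct = sum(p == t for p, t in pairs)  # a -1 prediction never matches a true label
    tp = sum(p == 1 and t == 1 for p, t in pairs)
    fp = sum(p == 1 and t != 1 for p, t in pairs)
    fn = sum(p != 1 and t == 1 for p, t in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(pairs), "f1_malignant": f1}


# Toy example: one correct benign, one correct malignant, one false alarm, one unparseable output
print(binary_metrics([0, 5, 4, -1], [1, 4, 2, 6]))
```

In the notebook, you'd call it as `binary_metrics(baseline_preds, eval_labels)`.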
A quick detour: Few-shot learning vs. fine-tuning
Before we start training, it’s worth asking: is fine-tuning the only way? Another popular technique is few-shot learning.
Few-shot learning is like giving a smart student a few examples of a new math problem right before a test. You aren't re-teaching them algebra; you're just showing them the specific pattern you want them to follow by providing examples directly in the prompt. This is a powerful technique, especially when you're using a closed model through an API where you can't access the internal weights.
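For comparison, a few-shot prompt in the chat format that this guide uses in Step 3 could look roughly like the following. Each solved example becomes a user turn (image plus prompt) and an assistant turn (its label), followed by the unsolved query. `build_few_shot_messages` is a hypothetical helper, and this approach wasn't evaluated here.

```python
def build_few_shot_messages(prompt, shots, query_image):
    """Build a chat history with one solved example per shot, then the query.

    shots: list of (image, label) pairs. Returns (messages, images), where images
    are listed in the same order as their {"type": "image"} placeholders.
    """
    messages, images = [], []
    for image, label in shots:
        messages.append({"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]})
        messages.append({"role": "assistant", "content": [{"type": "text", "text": str(label)}]})
        images.append(image)
    # The final user turn is the image we actually want classified
    messages.append({"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]})
    images.append(query_image)
    return messages, images


# Usage sketch with the processor from Step 4:
# messages, images = build_few_shot_messages(PROMPT, [(img_a, 0), (img_b, 4)], query_img)
# text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# inputs = processor(text=[text], images=[images], return_tensors="pt")
```

Every example image adds to the prompt length, which is one practical reason fine-tuning scales better when you have hundreds of labeled images.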
So why did we choose fine-tuning?
- We can host the model: Because MedGemma is an open model, we have direct access to its architecture. This access lets us perform fine-tuning to create a new, permanently updated version of the model.
- We have a good dataset: Fine-tuning lets the model learn the deep, underlying patterns in our hundreds of training images far more effectively than just showing it a few examples in a prompt.
In short, fine-tuning creates a true specialist model for our task, which is exactly what we want.
Step 6: Configure and run fine-tuning with LoRA
This is the main event! We'll use Low-Rank Adaptation (LoRA), which is much faster and more memory-efficient than full fine-tuning. LoRA works by freezing the original model weights and training only small, low-rank adapter matrices that are added to selected layers (here, all linear layers). Here's a breakdown of our parameter choices:
- r=8: The LoRA rank. A lower rank means fewer trainable parameters, which is faster but less expressive. A higher rank has more capacity, but risks overfitting on a small dataset. Rank 8 is a great starting point that balances performance and efficiency.
- lora_alpha=16: A scaling factor for the LoRA weights; the adapter's output is multiplied by lora_alpha / r, which is 2 here. A common rule of thumb is to set it to twice the rank (2 × r).
- lora_dropout=0.1: A regularization technique. During training, it randomly zeroes 10% of the inputs to the LoRA adapters, which helps prevent the model from becoming overly specialized and failing to generalize.
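To make those three parameters concrete, here's a from-scratch LoRA layer in plain PyTorch. It's an illustration of the idea, not PEFT's actual implementation, and the 2048 × 2048 layer size is arbitrary. It shows the frozen base weight, the rank-r down/up projections, dropout on the adapter input, and the lora_alpha / r scaling; because B starts at zero, the adapter is initially a no-op.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper around a frozen nn.Linear (not PEFT's implementation)."""

    def __init__(self, base: nn.Linear, r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.1):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # original weights stay frozen
        self.lora_A = nn.Linear(base.in_features, r, bias=False)   # down-projection to rank r
        self.lora_B = nn.Linear(r, base.out_features, bias=False)  # up-projection back
        nn.init.zeros_(self.lora_B.weight)  # B = 0, so the adapter starts as a no-op
        self.dropout = nn.Dropout(lora_dropout)  # applied to the adapter's input only
        self.scaling = lora_alpha / r  # 16 / 8 = 2.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling


base = nn.Linear(2048, 2048)
layer = LoRALinear(base)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(trainable, frozen)  # 32768 4196352
```

Here r=8 adds 8 × (2048 + 2048) = 32,768 trainable parameters next to about 4.2 million frozen ones, under 1% of the layer.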
```python
# LoRA configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Custom data collator to handle images and text
def collate_fn(examples):
    texts, images = [], []
    for example in examples:
        images.append([example["image"]])
        texts.append(processor.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False).strip())
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Mask padding and image-related tokens so they don't contribute to the loss
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # image soft-token ID in the Gemma 3 vocabulary
    batch["labels"] = labels
    return batch

# Training arguments
training_args = SFTConfig(
    output_dir="medgemma-breastcancer-finetuned",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size of 8
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",  # requires bitsandbytes
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,  # Warm up LR for first 3% of training
    max_grad_norm=0.3,  # Clip gradients to prevent instability
    bf16=True,  # Use bfloat16 precision
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="epoch",
    push_to_hub=False,
    report_to="none",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

# Initialize and run the trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_train,
    eval_dataset=formatted_eval,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

print("Starting training...")
trainer.train()
trainer.save_model()
```
The training took about 80 minutes on the 40 GB A100 GPU. The results looked promising, with the validation loss steadily decreasing.
Important (time-saving!) tip: If your training gets interrupted for any reason (like a connection issue or exceeding resource limits), you can resume from a saved checkpoint by passing the resume_from_checkpoint argument to trainer.train(). Checkpoints are written every save_steps steps (every 100 steps in the SFTConfig above, which extends TrainingArguments), so an interruption costs you at most that much progress.
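Transformers can locate the newest checkpoint for you: passing resume_from_checkpoint=True resumes from the last checkpoint in output_dir, and `transformers.trainer_utils.get_last_checkpoint` returns its path. If you'd rather inspect it before resuming, here's a standard-library sketch of the same lookup. `latest_checkpoint` is a hypothetical helper that assumes the default `checkpoint-<step>` folder naming.

```python
import os
import re

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str):
    """Return the path of the highest-numbered checkpoint-<step> folder, or None."""
    if not os.path.isdir(output_dir):
        return None
    found = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            found.append((int(match.group(1)), name))  # compare steps numerically, not as strings
    if not found:
        return None
    return os.path.join(output_dir, max(found)[1])


# Usage sketch: passing None simply starts training from scratch
# ckpt = latest_checkpoint(training_args.output_dir)
# trainer.train(resume_from_checkpoint=ckpt)
```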
Step 7: The final verdict – evaluating our fine-tuned model
After training, it’s time for the moment of truth. We’ll load our new LoRA adapter weights, merge them with the base model, and then run the same evaluation that we ran for the baseline.
```python
# Free GPU memory before reloading (the trainer also holds a reference to the model)
del model, trainer
gc.collect()
torch.cuda.empty_cache()

# Load base model again
base_model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

# Load LoRA adapters and merge them into a single model
finetuned_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
finetuned_model = finetuned_model.merge_and_unload()

# Load the processor saved alongside the adapters (needed before configuring generation)
processor_finetuned = AutoProcessor.from_pretrained(training_args.output_dir)

# Configure for generation
finetuned_model.generation_config.max_new_tokens = 50
finetuned_model.generation_config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
finetuned_model.config.pad_token_id = processor_finetuned.tokenizer.pad_token_id

# Run evaluation
finetuned_preds = batch_predict(finetuned_model, processor_finetuned, eval_prompts, eval_images, batch_size=4)
finetuned_metrics = compute_metrics(finetuned_preds, eval_labels)
```
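Merging works because the LoRA update is just a matrix that can be added to the frozen weight: W_merged = W + (lora_alpha / r) · B·A. The toy check below uses random matrices with arbitrary sizes (not real MedGemma weights) and ignores dropout, which is inactive at inference. It confirms that the adapter form and the merged form give the same output, which is why the merged model needs no PEFT machinery at inference time.

```python
import torch

torch.manual_seed(0)
d_in, d_out, r, lora_alpha = 64, 32, 8, 16
scaling = lora_alpha / r  # 2.0, matching r=8 and lora_alpha=16 from Step 6

W = torch.randn(d_out, d_in)      # frozen base weight
A = torch.randn(r, d_in) * 0.01   # trained LoRA down-projection
B = torch.randn(d_out, r) * 0.01  # trained LoRA up-projection
x = torch.randn(5, d_in)

# Adapter form: base output plus the scaled low-rank update
adapter_out = x @ W.T + scaling * (x @ A.T) @ B.T

# Merged form: fold the update into one weight matrix
W_merged = W + scaling * (B @ A)
merged_out = x @ W_merged.T

same = torch.allclose(adapter_out, merged_out, atol=1e-5)
print(same)
```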
Final results
So, how did fine-tuning impact performance? Let's look at accuracy and weighted F1 for the 8-class task, and accuracy and malignant-class F1 for the binary task.
```text
--- 8-Class Classification (0-7) ---
Model         Accuracy    F1 (Weighted)
-----------------------------------------------
Baseline      32.6%       0.241
Fine-tuned    87.2%       0.865
-----------------------------------------------

--- Binary (Benign/Malignant) Classification ---
Model         Accuracy    F1 (Malignant)
-----------------------------------------------
Baseline      59.6%       0.639
Fine-tuned    99.0%       0.991
-----------------------------------------------
```
The results are great! After fine-tuning, we see a dramatic improvement:
- 8-class: Accuracy jumped from 32.6% to 87.2% (+54.6 percentage points), and weighted F1 rose from 0.241 to 0.865.
- Binary: Accuracy increased from 59.6% to 99.0% (+39.4 percentage points), and malignant-class F1 rose from 0.639 to 0.991.
This project shows the incredible power of fine-tuning modern foundation models. We took a generalist AI that was already pre-trained on relevant medical data, gave it a small, specialized dataset, and taught it a new skill with remarkable efficiency. The journey from a generic model to a specialized classifier is more accessible than ever, opening up exciting possibilities for AI in medicine and beyond.
All of the code is available in the Finetune Notebook, which you can run on a GPU instance in Vertex AI Workbench.
Want to take it to production? Don't forget to catch the upcoming post, which shows you how to bring the fine-tuning and evaluation to Cloud Run jobs.
I hope this guide was helpful. Happy coding!
Special thanks to Fereshteh Mahvar and Dave Steiner from the MedGemma team for their helpful review and feedback on this post.
[1] IEEE Transactions on Biomedical Engineering, vol. 63, no. 7, pp. 1455-1462, 2016.
