Tutorial: Fine-tuning Your First Legal LLM¶
Learn how to fine-tune large language models for legal document analysis using JuDDGES infrastructure. This tutorial covers preparing instruction datasets, configuring training with PEFT/LoRA, and evaluating model performance.
Table of Contents¶
- Learning Objectives
- Prerequisites
- Step 1: Prepare Instruction Dataset
- Step 2: Configure Training
- Step 3: Run Fine-tuning
- Step 4: Evaluate Model
- Step 5: Use Your Fine-tuned Model
- Best Practices
- Troubleshooting
- Summary
Learning Objectives¶
By the end of this tutorial, you will:
- ✅ Prepare instruction datasets for legal tasks
- ✅ Configure PEFT/LoRA for efficient fine-tuning
- ✅ Train a model using JuDDGES infrastructure
- ✅ Evaluate fine-tuned model performance
- ✅ Deploy and use your custom legal LLM
Estimated Time: 60 minutes (+ training time)
GPU Required: Yes (40GB+ VRAM recommended)
Prerequisites¶
Required Knowledge¶
- Completion of Tutorial 1 and Tutorial 2
- Understanding of language models and fine-tuning concepts
- Familiarity with command line and Python
Required Hardware¶
- GPU: NVIDIA GPU with 40GB+ VRAM (A100, A6000, or similar)
- RAM: 64GB+ system RAM
- Storage: 100GB+ free space for models and datasets
Required Software¶
- JuDDGES environment with CUDA support
- DVC installed and configured
Step 1: Prepare Instruction Dataset¶
Understanding Instruction Format¶
Instruction datasets follow this format:
{
  "instruction": "Wyodrębnij sygnaturę sprawy z wyroku.",
  "input": "Sąd Okręgowy w Warszawie, sygn. II C 123/2023...",
  "output": "II C 123/2023"
}
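Before building a dataset, it can help to check that every record really has these three string fields. The following is a minimal sketch and not part of JuDDGES: the helper name `validate_record` is hypothetical, and allowing an empty `input` is an assumption made here for tasks that need no context.

```python
"""Validate instruction records before building a dataset (illustrative sketch)."""

REQUIRED_FIELDS = ("instruction", "input", "output")


def validate_record(record: dict) -> list[str]:
    """Return a list of problems found in one record (empty if it is valid)."""
    problems = []
    for field in REQUIRED_FIELDS:
        value = record.get(field)
        if not isinstance(value, str):
            problems.append(f"'{field}' is missing or not a string")
        elif field != "input" and not value.strip():
            # Assumption: instruction and output must carry content,
            # while input may legitimately be empty.
            problems.append(f"'{field}' is empty")
    return problems


good = {
    "instruction": "Wyodrębnij sygnaturę sprawy z wyroku.",
    "input": "Sąd Okręgowy w Warszawie, sygn. II C 123/2023...",
    "output": "II C 123/2023",
}
bad = {"instruction": "Wyodrębnij sygnaturę sprawy z wyroku.", "output": ""}

print(validate_record(good))
print(validate_record(bad))
```

Running a check like this over all samples before calling `Dataset.from_list` surfaces malformed records early, instead of partway through a training run.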
Create Sample Dataset¶
"""Create an instruction dataset for legal information extraction."""
from datasets import Dataset
from rich.console import Console
console = Console()
# Sample legal documents with extraction tasks
samples = [
    {
        "instruction": "Wyodrębnij datę wyroku w formacie ISO 8601.",
        "input": "Wyrok z dnia 15 marca 2023 roku. Sąd Okręgowy w Warszawie orzekł...",
        "output": "2023-03-15",
    },
    {
        "instruction": "Wyodrębnij nazwę sądu.",
        "input": "W imieniu Rzeczypospolitej Polskiej. Sąd Okręgowy w Krakowie...",
        "output": "Sąd Okręgowy w Krakowie",
    },
    {
        "instruction": "Wyodrębnij sygnaturę sprawy.",
        "input": "Sygnatura akt: II C 456/2023. Sąd rozpoznał sprawę...",
        "output": "II C 456/2023",
    },
    # Add more samples...
]
# Create dataset
dataset = Dataset.from_list(samples)
# Split into train/test
dataset = dataset.train_test_split(test_size=0.2, seed=42)
# Save
dataset.save_to_disk("./data/instruction_dataset")
console.print(f"[green]✓ Created dataset with {len(dataset['train'])} training examples[/green]")
Load Existing Dataset¶
from datasets import load_dataset
# Load pre-built Swiss franc loans dataset
dataset = load_dataset("JuDDGES/swiss_franc_loans_instruct")
print(f"Train: {len(dataset['train'])} examples")
print(f"Test: {len(dataset['test'])} examples")
# Inspect sample
sample = dataset['train'][0]
print(f"Instruction: {sample['instruction'][:100]}...")
print(f"Output: {sample['output'][:100]}...")
Step 2: Configure Training¶
Create Configuration File¶
Create configs/my_finetuning.yaml:
# Model configuration
llm:
  model_name: "meta-llama/Llama-3.2-3B-Instruct"
  tokenizer_name: "meta-llama/Llama-3.2-3B-Instruct"
  torch_dtype: "bfloat16"
  device_map: "auto"
# Dataset configuration
dataset:
  name: "JuDDGES/swiss_franc_loans_instruct"
  instruction_field: "instruction"
  input_field: "input"
  output_field: "output"
# PEFT/LoRA configuration
peft:
  use_peft: true
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
# Training configuration
training:
  output_dir: "./outputs/llama-3.2-3b-legal"
  num_train_epochs: 3
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 4
  learning_rate: 2e-4
  warmup_steps: 100
  logging_steps: 10
  save_steps: 500
  evaluation_strategy: "steps"
  eval_steps: 500
  fp16: false
  bf16: true
  max_grad_norm: 0.3
  optim: "paged_adamw_8bit"
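To get a feel for what `lora_r` and `target_modules` cost, the number of trainable parameters can be estimated by hand. The sketch below assumes the Llama-3.2-3B dimensions (hidden size 3072, 28 layers, 8 key/value heads with head dimension 128); these values are assumptions stated here and should be checked against the model's `config.json`. The function name is hypothetical.

```python
"""Estimate how many trainable parameters a LoRA configuration adds (sketch)."""

# Assumed Llama-3.2-3B dimensions; confirm against the model's config.json.
HIDDEN_SIZE = 3072
NUM_LAYERS = 28
KV_DIM = 1024  # 8 key/value heads x head dimension 128

# (input features, output features) of each targeted projection
PROJECTIONS = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "k_proj": (HIDDEN_SIZE, KV_DIM),
    "v_proj": (HIDDEN_SIZE, KV_DIM),
    "o_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
}


def lora_trainable_params(lora_r: int, target_modules: list[str]) -> int:
    """LoRA adds two matrices per targeted layer: A (r x d_in) and B (d_out x r)."""
    per_layer = sum(
        lora_r * (PROJECTIONS[name][0] + PROJECTIONS[name][1])
        for name in target_modules
    )
    return per_layer * NUM_LAYERS


params = lora_trainable_params(16, ["q_proj", "k_proj", "v_proj", "o_proj"])
print(f"Trainable LoRA parameters: {params:,}")
print(f"LoRA scaling (lora_alpha / lora_r): {32 / 16}")
```

Under these assumptions the adapter trains only a small fraction of a percent of the base model's weights, which is why LoRA fits on a single GPU. Doubling `lora_r` doubles the adapter size.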
Verify Configuration¶
from omegaconf import OmegaConf
# Load config
config = OmegaConf.load("configs/my_finetuning.yaml")
# Verify settings
print(f"Model: {config.llm.model_name}")
print(f"Dataset: {config.dataset.name}")
print(f"LoRA rank: {config.peft.lora_r}")
print(f"Epochs: {config.training.num_train_epochs}")
print(f"Batch size: {config.training.per_device_train_batch_size}")
Step 3: Run Fine-tuning¶
Using the Fine-tuning Script¶
# Set GPU
export CUDA_VISIBLE_DEVICES=0
# Run fine-tuning
python scripts/sft/fine_tune_llm.py \
  --config configs/my_finetuning.yaml \
  --output_dir ./outputs/my-legal-model
Using DVC Pipeline¶
# Configure DVC stage
dvc stage add -n finetune \
  -d data/instruction_dataset \
  -d configs/my_finetuning.yaml \
  -o outputs/my-legal-model \
  "python scripts/sft/fine_tune_llm.py --config configs/my_finetuning.yaml --output_dir ./outputs/my-legal-model"
# Run pipeline
dvc repro finetune
Monitor Training¶
"""Monitor training progress."""
import json
from pathlib import Path
import pandas as pd
import plotly.express as px
# Load training logs
log_file = Path("outputs/my-legal-model/trainer_state.json")
if log_file.exists():
    with open(log_file) as f:
        state = json.load(f)
    # Extract loss history; evaluation entries have no "loss" key, so skip them
    history = state["log_history"]
    df = pd.DataFrame([entry for entry in history if "loss" in entry])
    # Plot training loss
    fig = px.line(df, x="step", y="loss", title="Training Loss")
    fig.show()
    print(f"Current epoch: {state['epoch']}")
    print(f"Global step: {state['global_step']}")
    print(f"Best metric: {state.get('best_metric', 'N/A')}")
else:
    print(f"No training state found at {log_file}")
Expected Training Time:
- Llama-3.2-3B: ~2-3 hours on A100 (1000 examples)
- Llama-3.1-8B: ~6-8 hours on A100 (1000 examples)
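The training time depends on how many optimizer steps the run performs, which follows from the batch settings in the configuration. The sketch below works through that arithmetic for 1000 examples. The function name is hypothetical, and the step count is approximate because the trainer's exact rounding may differ.

```python
"""Rough optimizer-step arithmetic for a fine-tuning run (illustrative sketch)."""
import math


def approx_optimizer_steps(
    num_examples: int,
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    num_epochs: int,
    num_devices: int = 1,
) -> tuple[int, int]:
    """Return (effective batch size, approximate total optimizer steps)."""
    effective_batch = per_device_batch_size * gradient_accumulation_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * num_epochs


effective_batch, total_steps = approx_optimizer_steps(1000, 4, 4, 3)
print(f"Effective batch size: {effective_batch}")
print(f"Approximate optimizer steps: {total_steps}")

# Compare the step-based settings from the config with the run length
for name, value in {"warmup_steps": 100, "eval_steps": 500, "save_steps": 500}.items():
    if value > total_steps:
        note = "exceeds the total step count"
    else:
        note = f"{value / total_steps:.0%} of training"
    print(f"{name}={value}: {note}")
```

For a dataset of this size, the arithmetic suggests that `eval_steps: 500` and `save_steps: 500` would never trigger and that warmup would cover roughly half of the run, so lower values may be worth considering for small datasets.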
Step 4: Evaluate Model¶
Run Evaluation¶
# Evaluate on test set
python scripts/sft/evaluate.py \
--model_path ./outputs/my-legal-model \
--dataset JuDDGES/swiss_franc_loans_instruct \
--split test \
--output_file ./outputs/evaluation_results.json
Evaluation Metrics¶
"""Analyze evaluation results."""
import json
from rich.console import Console
from rich.table import Table
console = Console()
# Load results
with open("./outputs/evaluation_results.json") as f:
    results = json.load(f)
# Display metrics
table = Table(title="Evaluation Metrics")
table.add_column("Metric", style="cyan")
table.add_column("Score", style="green")
for metric, score in results["metrics"].items():
    table.add_row(metric, f"{score:.4f}")
console.print(table)
# Show example predictions
console.print("\n[bold]Example Predictions:[/bold]")
for i, example in enumerate(results["examples"][:3], 1):
    console.print(f"\n[cyan]Example {i}:[/cyan]")
    console.print(f"Input: {example['input'][:100]}...")
    console.print(f"Expected: {example['expected'][:100]}...")
    console.print(f"Predicted: {example['predicted'][:100]}...")
    console.print(f"Match: {example['match']}")
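To make the reported scores easier to interpret, here is a minimal sketch of how exact match and token-level F1 are commonly computed for extraction outputs. It is not necessarily how `scripts/sft/evaluate.py` computes them. The normalization (lowercasing and collapsing whitespace) is an assumption, and the function names are hypothetical.

```python
"""Exact match and token-level F1 for extraction outputs (illustrative sketch)."""
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so trivial differences do not count."""
    return " ".join(text.lower().split())


def exact_match(predicted: str, expected: str) -> bool:
    """True when the normalized strings are identical."""
    return normalize(predicted) == normalize(expected)


def token_f1(predicted: str, expected: str) -> float:
    """Harmonic mean of token precision and recall over whitespace tokens."""
    pred_tokens = normalize(predicted).split()
    gold_tokens = normalize(expected).split()
    if not pred_tokens or not gold_tokens:
        # Both empty counts as a perfect match; one empty counts as a miss.
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("II C 456/2023", "ii c  456/2023"))
print(token_f1("Sąd Okręgowy", "Sąd Okręgowy w Krakowie"))
```

Exact match is strict and suits short fields such as case numbers. Token F1 gives partial credit, which is more informative for longer fields such as court names.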
Compare with Baseline¶
"""Compare fine-tuned model with base model."""
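import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
base_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(base_name)
# Load both models (if the output directory holds only LoRA adapter weights,
# loading it this way requires the peft package to be installed)
base_model = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype=torch.bfloat16, device_map="auto")
finetuned_model = AutoModelForCausalLM.from_pretrained("./outputs/my-legal-model", torch_dtype=torch.bfloat16, device_map="auto")
def generate(model, prompt: str, max_new_tokens: int = 64) -> str:
    """Greedy-decode a completion and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Test on same examples
test_input = "Wyodrębnij datę wyroku: Wyrok z dnia 15 marca 2023..."
# Generate from both
base_output = generate(base_model, test_input)
finetuned_output = generate(finetuned_model, test_input)
print(f"Base model: {base_output}")
print(f"Fine-tuned: {finetuned_output}")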
Step 5: Use Your Fine-tuned Model¶
Load the Model and Run Inference¶
"""Use your fine-tuned model for inference."""
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model
model_path = "./outputs/my-legal-model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
def extract_information(instruction: str, text: str) -> str:
    """Extract information from legal text."""
    # Format prompt
    prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{text}\n\n### Output:\n"
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate with greedy decoding (temperature has no effect when do_sample=False)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )
    # Decode
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract output section
    if "### Output:" in result:
        result = result.split("### Output:")[1].strip()
    return result
# Test
instruction = "Wyodrębnij datę wyroku w formacie ISO 8601."
text = "Wyrok Sądu Okręgowego w Warszawie z dnia 15 marca 2023 roku..."
output = extract_information(instruction, text)
print(f"Extracted: {output}")
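Because the instruction asks for an ISO 8601 date, it is worth validating the model's answer before storing it. The sketch below shows one way to do that with the standard library; the helper name `parse_iso_date` is hypothetical, and restricting the format to `YYYY-MM-DD` is an assumption about what downstream code expects.

```python
"""Validate that an extracted value is a real YYYY-MM-DD date (illustrative sketch)."""
import re
from datetime import date
from typing import Optional

ISO_DATE_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}")


def parse_iso_date(value: str) -> Optional[date]:
    """Return a date for a valid YYYY-MM-DD string, otherwise None."""
    value = value.strip()
    if not ISO_DATE_PATTERN.fullmatch(value):
        return None
    try:
        # Rejects impossible dates such as 2023-02-30
        return date.fromisoformat(value)
    except ValueError:
        return None


for candidate in ["2023-03-15", "15 marca 2023", "2023-02-30"]:
    print(f"{candidate!r} -> {parse_iso_date(candidate)}")
```

A `None` result can be logged or routed to manual review, so malformed generations do not silently enter a structured dataset.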
Deploy as API¶
"""Simple FastAPI deployment."""
from fastapi import FastAPI
from pydantic import BaseModel
# Assumes the inference code above is saved as inference.py next to this file
# (a hypothetical module name); importing it loads the model once at startup.
from inference import extract_information
app = FastAPI()
class ExtractionRequest(BaseModel):
    instruction: str
    text: str
@app.post("/extract")
def extract(request: ExtractionRequest):
    result = extract_information(
        request.instruction,
        request.text,
    )
    return {"result": result}
# Run: uvicorn api:app --reload
Best Practices¶
1. Dataset Quality¶
# ✅ Good: Clear, specific instructions
{
  "instruction": "Extract the court name from the judgment header.",
  "input": "Court of Appeal, London...",
  "output": "Court of Appeal, London"
}
# ❌ Bad: Vague, ambiguous
{
  "instruction": "Get the important info.",
  "input": "Some text...",
  "output": "Stuff"
}
2. Hyperparameter Tuning¶
Start with these defaults:
# For 3B-8B models
lora_r = 16 # Rank (higher = more capacity)
lora_alpha = 32 # Scaling factor
learning_rate = 2e-4 # Learning rate
batch_size = 4 # Per device
# For 11B-70B models
lora_r = 32
lora_alpha = 64
learning_rate = 1e-4
batch_size = 1
3. Training Monitoring¶
# Check for overfitting: eval_loss starts increasing while train_loss keeps falling
# -> reduce epochs or add regularization
num_epochs = 2 # Instead of 3
lora_dropout = 0.1 # Instead of 0.05
# Check for underfitting: train_loss plateaus at a high value
# -> increase capacity or learning rate
lora_r = 32 # Instead of 16
learning_rate = 3e-4 # Instead of 2e-4
4. Evaluation¶
Always evaluate on:
- ✅ Held-out test set (not used in training)
- ✅ Real-world examples
- ✅ Edge cases and difficult examples
- ✅ Multiple metrics (exact match, F1, BLEU)
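A held-out test set is only meaningful if none of its inputs also appear in the training data. The following sketch checks for that kind of leakage on plain lists of records. The function names are hypothetical, and comparing inputs after lowercasing and whitespace normalization is an assumption; near-duplicates would need a fuzzier comparison.

```python
"""Check that no test input also appears in the training split (illustrative sketch)."""


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace before comparing."""
    return " ".join(text.lower().split())


def find_leaked_inputs(train: list[dict], test: list[dict]) -> list[str]:
    """Return test inputs whose normalized text also occurs in the training set."""
    train_inputs = {normalize(row["input"]) for row in train}
    return [row["input"] for row in test if normalize(row["input"]) in train_inputs]


train_rows = [
    {"input": "Sygnatura akt: II C 456/2023. Sąd rozpoznał sprawę..."},
    {"input": "Wyrok z dnia 15 marca 2023 roku."},
]
test_rows = [
    {"input": "wyrok z dnia  15 marca 2023 roku."},
    {"input": "W imieniu Rzeczypospolitej Polskiej."},
]

leaked = find_leaked_inputs(train_rows, test_rows)
print(f"Leaked test inputs: {len(leaked)}")
```

With a `datasets` split, the same check can be run by passing `list(dataset["train"])` and `list(dataset["test"])`.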
Troubleshooting¶
Issue: "CUDA out of memory"¶
Solutions:
# 1. Reduce batch size
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
# 2. Use 8-bit quantization
load_in_8bit = True
# 3. Enable gradient checkpointing
gradient_checkpointing = True
# 4. Use smaller model
model_name = "meta-llama/Llama-3.2-3B-Instruct" # Instead of 8B
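To judge which of these options is likely to help, a back-of-the-envelope estimate of the memory taken by the weights alone is a useful starting point. The sketch below uses round parameter counts (3e9 and 8e9) as assumptions. It covers weights only, so activations, gradients, and optimizer state come on top. The function name is hypothetical.

```python
"""Back-of-the-envelope memory needed for model weights alone (illustrative sketch)."""


def weight_memory_gib(num_params: float, bits_per_param: int) -> float:
    """Memory for the weights only, in GiB."""
    return num_params * bits_per_param / 8 / 2**30


# Round parameter counts, used here as assumptions
for name, num_params in {"3B model": 3e9, "8B model": 8e9}.items():
    for bits in (16, 8):
        print(f"{name} at {bits}-bit: {weight_memory_gib(num_params, bits):.1f} GiB")
```

Halving the bits per parameter halves the weight footprint, whereas reducing the batch size and enabling gradient checkpointing target activation memory instead.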
Issue: "Model not learning"¶
Solutions:
# 1. Increase learning rate
learning_rate = 5e-4 # Instead of 2e-4
# 2. Increase LoRA rank
lora_r = 32 # Instead of 16
# 3. Train longer
num_train_epochs = 5 # Instead of 3
# 4. Check data quality
# Ensure instructions are clear and outputs are correct
Issue: "Model overfitting"¶
Solutions:
# 1. Reduce epochs
num_train_epochs = 2
# 2. Increase dropout
lora_dropout = 0.1
# 3. Add more training data
# Collect more diverse examples
# 4. Use early stopping
early_stopping_patience = 3
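The idea behind `early_stopping_patience` can be sketched in a few lines: training stops once the evaluation loss has failed to improve for a given number of consecutive evaluations. The function below is an illustration of that logic with a hypothetical name, not the code the fine-tuning script runs. Hugging Face `transformers` ships an `EarlyStoppingCallback` for this purpose, which expects `load_best_model_at_end=True` and a `metric_for_best_model`. Whether the JuDDGES script exposes it through the config is not confirmed here.

```python
"""Patience-based early stopping on evaluation loss (illustrative sketch)."""


def should_stop(eval_losses: list[float], patience: int = 3) -> bool:
    """Stop once the last `patience` evaluations all failed to beat the best earlier loss."""
    if len(eval_losses) <= patience:
        return False
    best_before = min(eval_losses[:-patience])
    return all(loss >= best_before for loss in eval_losses[-patience:])


improving = [1.00, 0.80, 0.70, 0.60]
overfitting = [1.00, 0.80, 0.70, 0.72, 0.75, 0.90]

print(should_stop(improving))
print(should_stop(overfitting))
```

Note that early stopping can only act when evaluations actually happen during the run, so `eval_steps` needs to be small enough relative to the total number of training steps.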
Summary¶
You've successfully fine-tuned a legal LLM!
What You've Learned¶
- ✅ Dataset Preparation: Created instruction datasets
- ✅ Configuration: Set up PEFT/LoRA parameters
- ✅ Training: Fine-tuned models efficiently
- ✅ Evaluation: Assessed model performance
- ✅ Deployment: Loaded the fine-tuned model for inference and served it through a simple API
Key Metrics¶
As rough targets (actual results depend on the dataset, base model, and hyperparameters), your fine-tuned model should achieve:
- Exact Match: 70-85% on legal extraction tasks
- F1 Score: 80-90% on field-level evaluation
- Training Time: 2-8 hours depending on model size
- Inference: <1s per prediction
Next Steps¶
- Tutorial 4: Advanced Information Extraction - Complex schemas and pipelines
- Tutorial 5: End-to-End Project - Build complete system
Last Updated: 2025-10-11 | Version: 1.0 | Status: Published
Difficulty: Advanced | Prerequisites: GPU access, Tutorials 1-2