SFT results inspection

import warnings
import json
from multiprocessing import Pool
from statistics import mean
from typing import Any
from pathlib import Path

import pandas as pd
from tqdm.auto import tqdm
from ipywidgets import interact

from juddges.utils.misc import parse_yaml
from juddges.metrics.info_extraction import evaluate_extraction

pd.options.display.float_format = '{:,.3f}'.format
warnings.filterwarnings('ignore', message="To copy construct from a tensor, it is recommended to use")

Compare metrics

Load each model's metrics_*.json file and assemble a single comparison table: full-text CHRF plus per-field CHRF scores.

# Each metrics_<model>.json file holds a full-text CHRF score and per-field CHRF scores.
results = []
for f in Path("../../data/experiments/predict/pl-court-instruct").glob("metrics_*.json"):
    model_name = f.stem.replace("metrics_", "")
    with f.open() as file:
        m_res = json.load(file)
    results.append(
        {"llm": model_name}
        | {"full_text_chrf": m_res["full_text_chrf"]}
        | m_res["field_chrf"]
    )

pd.DataFrame(results).sort_values("llm")
                                           llm  full_text_chrf  court_name   date  department_name  judges  legal_bases  recorder  signature
2                     Meta-Llama-3-8B-Instruct           0.247       0.862  0.971            0.833   0.882        0.287     0.805      0.778
0                     Mistral-7B-Instruct-v0.2           0.432       0.839  0.922            0.850   0.879        0.333     0.837      0.145
3          Mistral-7B-Instruct-v0.2-fine-tuned           0.772       0.987  0.990            0.965   0.952        0.600     0.979      0.972
4       Unsloth-Llama-3-8B-Instruct-fine-tuned           0.828       0.995  0.989            0.986   0.977        0.601     0.993      0.994
1             Unsloth-Mistral-7B-Instruct-v0.3           0.477       0.830  0.987            0.900   0.870        0.419     0.943      0.567
5  Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned           0.798       0.995  0.988            0.986   0.967        0.608     0.987      0.976
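
Fine-tuning lifts every score in the table, with legal_bases remaining the hardest field. A minimal sketch of the per-model gain, assuming the naming convention that fine-tuned runs carry a "-fine-tuned" suffix on the base model's name:

# Sketch: full-text CHRF gain from fine-tuning per base model.
# Assumes fine-tuned runs are named "<base>-fine-tuned" (naming convention assumed).
df = pd.DataFrame(results).set_index("llm")
for name in df.index:
    if name.endswith("-fine-tuned"):
        base = name.removesuffix("-fine-tuned")
        if base in df.index:
            gain = df.loc[name, "full_text_chrf"] - df.loc[base, "full_text_chrf"]
            print(f"{base}: +{gain:.3f} full-text CHRF after fine-tuning")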

Inspect results

Score each prediction individually, parse the gold and predicted YAML, and sort examples from worst to best so failure cases surface first.

OUTPUTS_PATH = "../../data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json"

with open(OUTPUTS_PATH) as file:
    data = json.load(file)

def eval_item(item: dict[str, Any]) -> dict[str, Any]:
    # Score a single prediction, then parse the gold and predicted YAML.
    item["metrics"] = evaluate_extraction([item])
    item["metrics"]["mean_field"] = mean(item["metrics"]["field_chrf"].values())
    item["gold"] = parse_yaml(item["gold"])
    try:
        item["answer"] = parse_yaml(item["answer"])
    except Exception:  # model produced malformed YAML
        item["answer"] = None
    return item

num_invalid_answers = 0
results = []
# Score and parse items in parallel; count outputs whose YAML failed to parse.
with Pool(10) as pool:
    for item in tqdm(pool.imap(eval_item, data), total=len(data)):
        results.append(item)
        if item["answer"] is None:
            num_invalid_answers += 1

print(f"Number of invalid answers: {num_invalid_answers} / {len(data)}")
Number of invalid answers: 224 / 2000
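
To see why 224 outputs were rejected, one can peek at a few of the raw strings. Pool.imap preserves input order, so results[i] lines up with data[i], and the items in data still hold the unparsed "answer" strings (the workers only receive pickled copies):

# Sketch: print the first few raw outputs that failed YAML parsing.
shown = 0
for raw, parsed in zip(data, results):
    if parsed["answer"] is None:
        print(raw["answer"][:300], "\n---")
        shown += 1
        if shown == 3:
            break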
# Keep only parseable answers, ordered from worst to best mean per-field CHRF.
data_valid = [item for item in results if item["answer"] is not None]
data_valid = sorted(data_valid, key=lambda x: x["metrics"]["mean_field"])
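
Before browsing items one by one, a glance at the distribution of per-item scores (a convenience sketch over the values computed above) shows how the errors are spread:

# Sketch: summary statistics of the mean per-field CHRF among valid answers.
scores = pd.Series([item["metrics"]["mean_field"] for item in data_valid])
print(scores.describe())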

def item_to_df(idx: int) -> pd.DataFrame:
    """Side-by-side gold vs. predicted fields with per-field CHRF scores."""
    item = data_valid[idx]
    return pd.DataFrame({
        "gold": item["gold"],
        "answer": item["answer"],
        "metrics": item["metrics"]["field_chrf"],
    })


interact(item_to_df, idx=range(len(data_valid)));
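
Since data_valid is sorted ascending by mean field CHRF, index 0 is the worst-scoring example; outside a notebook the same view is available without widgets:

# Sketch: non-widget fallback; inspect the worst-scoring valid example directly.
print(item_to_df(0))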