Local subgraphs analysis

import itertools
import json

import polars as pl
import networkx as nx
import seaborn as sns
import pandas as pd
import random
import openai
from langchain_core.utils.json import parse_json_markdown
from functools import partial
from tqdm.auto import tqdm

sns.set_theme("notebook")

# Load the serialized judgment graph (node-link JSON) and rebuild it.
GRAPH_PATH = "../../data/datasets/pl/graph/data/judgment_graph.json"
with open(GRAPH_PATH) as graph_file:
    g_data = json.load(graph_file)

g = nx.node_link_graph(g_data)
# Bipartite split: judgment nodes on one side, legal-base nodes on the other.
src_nodes, target_nodes = nx.bipartite.sets(g)
# Lazy scan of the raw judgment texts; rows are collected later on demand.
ds = pl.scan_parquet("../../data/datasets/pl/raw/*.parquet")

Investigate local parts of graph

Extract local graph

# Attach each node's degree as a node attribute so it can be filtered on.
nx.set_node_attributes(g, dict(g.degree()), "degree")
# All nodes with their data, ordered from most- to least-connected.
deg_sorted_nodes = sorted(
    g.nodes(data=True), key=lambda node: node[1]["degree"], reverse=True
)

def get_legal_bases_with_deg(deg: int) -> list[int]:
    """Return ids of legal-base nodes whose degree equals *deg*."""
    matching = []
    for node_id, attrs in deg_sorted_nodes:
        if attrs["degree"] == deg and node_id in target_nodes:
            matching.append(node_id)
    return matching

def get_judgments_of_legal_base(legal_base_id: int) -> list[int]:
    """Return the 2-hop neighborhood of *legal_base_id* in the bipartite graph.

    That is: every judgment citing the legal base, plus every legal base
    those judgments cite. The result may contain duplicates; callers use it
    as an ``nbunch`` where duplicates are harmless.
    """
    directed = g.to_directed()
    judgments = list(directed.predecessors(legal_base_id))
    cited_bases = [
        base
        for judgment in judgments
        for base in directed.successors(judgment)
    ]
    return judgments + cited_bases
# Pick an arbitrary legal base cited by exactly 4 judgments.
LB = get_legal_bases_with_deg(4)[0]
TITLE = g.nodes[LB]["title"]
neighborhood = get_judgments_of_legal_base(LB)
print(f"Found nodes: {len(neighborhood)=}")
sg = nx.induced_subgraph(g, nbunch=neighborhood)

# Node attributes as a frame; "index" column holds the graph node id.
cases = (
    pd.DataFrame.from_dict(dict(sg.nodes(data=True)), orient="index")
    .reset_index()
    .sort_values(["node_type", "date"])
)
# Legal-base rows carry no "_id", so dropna keeps only judgment ids.
case_ids = cases["_id"].dropna().tolist()

# Fetch the full text of just the selected judgments.
cases_text = (
    ds.filter(pl.col("_id").is_in(case_ids)).select(["_id", "text"]).collect()
)
cases = cases[["index", "_id"]].merge(cases_text.to_pandas(), on="_id", how="right")
cases.head()

Summarize judgments

# Local OpenAI-compatible endpoint (e.g. llama.cpp / vLLM); the key is unused.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="sk-no-key-required")

# Concatenate the first 3000 characters of each judgment as the LLM context.
judgment_texts = cases_text.to_dict(as_series=False)["text"]
llm_input = "\n\n".join(text[:3000] for text in judgment_texts)

# Prompt template for batch-summarizing judgments; ``{context}`` receives the
# concatenated judgment texts. Double braces escape the literal JSON braces
# for ``str.format``. BUGFIX: the opening code fence had only two backticks
# (``json) while the closing one had three, giving the model a malformed
# markdown example; also fixed the "followint"/"Always response" typos.
INPUT_PROMPT = """
You are an AI tasked with summarizing multiple Polish court judgments. Always respond in English, use formal language.
First, provide an overall_summary which is a single sentence that encapsulates the common topic of all the judgments, don't be too general.
Then, for each judgment, provide a one-sentence judgment_summary, including the main reason for the decision, preserve order of judgments. 
For each judgment provide keyphrases summarizing it.

Summarize following judgments:
====
{context}
====

Format response as JSON:
```json
{{
    overall_summary: string,
    judgment_summaries: list of string,
    keyphrases: list of lists of string,
}}
```
"""

# Single chat completion against the local server; the model name is ignored
# by the OpenAI-compatible endpoint.
completion = client.chat.completions.create(
    model="not-required",
    messages=[{"role": "user", "content": INPUT_PROMPT.format(context=llm_input)}],
)

response = completion.choices[0].message.content

# Best-effort parse: the model may wrap the JSON in markdown or return junk,
# so on failure the raw text is dumped for manual inspection.
# NOTE(review): if parsing fails, ``summary`` stays undefined and the cells
# below will raise NameError — acceptable in an interactive notebook.
try:
    summary = parse_json_markdown(response)
except Exception:
    print("Couldn't parse, raw response:")
    print(response)
else:
    print(summary)
# Map judgment "_id" -> graph node id (the "index" column) so the LLM output
# can be written back onto the corresponding graph nodes.
iid_2_index = dict(zip(cases["_id"], cases["index"]))
# NOTE(review): assumes the model preserved judgment order, i.e. the i-th
# summary/keyphrase list belongs to the i-th row of cases_text — confirm.
judgment_ids = cases_text["_id"].to_list()
summary_node_attr = {
    iid_2_index[jid]: text
    for jid, text in zip(judgment_ids, summary["judgment_summaries"])
}
kp_node_attr = {
    iid_2_index[jid]: phrases
    for jid, phrases in zip(judgment_ids, summary["keyphrases"])
}
nx.set_node_attributes(sg, summary_node_attr, name="summary")
nx.set_node_attributes(sg, kp_node_attr, name="keyphrases")

Visualize

from bokeh.io import output_notebook, show
from bokeh.models import Circle, ColumnDataSource, LabelSet, MultiLine, Range1d
from bokeh.plotting import figure, from_networkx
from bokeh.transform import linear_cmap

output_notebook()

# Tooltip rows shown when hovering a node.
HOVER_TOOLTIPS = [
    ("Date", "@date"),
    ("Summary", "@summary"),
    ("ISAP", "@isap_id"),
]

# Numeric color code per node type; drives the linear colormap below.
COLOR_MAP = {
    "judgment": 0,
    "legal_base": 1,
}

# Store each node's numeric color code under the "nt" attribute.
node_colors = {}
for node_id, node_data in sg.nodes(data=True):
    node_colors[node_id] = COLOR_MAP[node_data["node_type"]]
nx.set_node_attributes(sg, node_colors, name="nt")

color_by_this_attribute = "nt"
color_palette = ("#EA1D15", "#15E2EA")

# Figure with hover tooltips; the fixed +/-10 ranges match the layout scale
# used when positioning the network below.
plot = figure(
    tooltips=HOVER_TOOLTIPS,
    tools="pan,wheel_zoom,save,reset",
    active_scroll="wheel_zoom",
    x_range=Range1d(-10.1, 10.1),
    y_range=Range1d(-10.1, 10.1),
    width=1_200,
    height=600,
)

# Hide grid and axes: layout coordinates carry no meaning for the reader.
for chrome in (plot.xgrid, plot.ygrid, plot.xaxis, plot.yaxis):
    chrome.visible = False
# Judgments on one side of the bipartite layout, legal bases on the other.
n_ids = [n_id for n_id in sg.nodes if sg.nodes[n_id]["node_type"] == "judgment"]
n_ids_2 = [n_id for n_id in sg.nodes if sg.nodes[n_id]["node_type"] == "legal_base"]
network_graph = from_networkx(sg, partial(nx.bipartite_layout, nodes=n_ids), scale=10, center=(0, 0))

# Color nodes by node type ("nt" attribute) over the two-color palette.
minimum_value_color = min(network_graph.node_renderer.data_source.data[color_by_this_attribute])
maximum_value_color = max(network_graph.node_renderer.data_source.data[color_by_this_attribute])
network_graph.node_renderer.glyph = Circle(
    radius=0.30,
    fill_color=linear_cmap(color_by_this_attribute, color_palette, minimum_value_color, maximum_value_color),
)

# Set edge opacity and width
network_graph.edge_renderer.glyph = MultiLine(line_alpha=0.5, line_width=1)

# Label every node with its keyphrases at its own layout position.
# BUGFIX: iterate the layout dict itself so each label is paired with the
# coordinates of the SAME node — the previous zip combined graph_layout's
# value order with an unrelated n_ids + n_ids_2 order, misplacing labels.
# BUGFIX: use .get() with a default, because only judgment nodes receive a
# "keyphrases" attribute; legal-base nodes previously raised KeyError.
graph_layout = network_graph.layout_provider.graph_layout
xs, ys, names = [], [], []
for node_id, (node_x, node_y) in graph_layout.items():
    xs.append(node_x)
    ys.append(node_y)
    names.append(",".join(sg.nodes[node_id].get("keyphrases", [])))
source = ColumnDataSource({"x": xs, "y": ys, "name": names})
labels = LabelSet(
    x="x",
    y="y",
    text="name",
    source=source,
    background_fill_color="white",
    text_font_size="14px",
    background_fill_alpha=1.0,
)
plot.renderers.append(labels)

plot.renderers.append(network_graph)

show(plot)

Community detection

def connected_legal_bases(g: "nx.Graph", nbunch: list) -> list:
    """Return the unique neighbors (legal bases) of the judgment nodes in *nbunch*.

    BUGFIX: the previous implementation scanned every edge of *g* and only
    checked whether the FIRST element of the edge tuple was in *nbunch*; for
    an undirected graph the tuple orientation is not guaranteed, so edges
    whose judgment appeared second were silently missed. Walking each seed
    node's adjacency via ``g.neighbors`` is orientation-independent and costs
    O(sum of seed degrees) instead of O(|E|).
    """
    seeds = set(nbunch)
    return list({neighbor for node in seeds for neighbor in g.neighbors(node)})
def sample_subgraph_randomly(g: nx.Graph, k: int) -> nx.Graph:
    """Induce a subgraph on *k* random judgments plus their connected legal bases."""
    judgments = random.sample(list(src_nodes), k=k)
    node_ids = judgments + connected_legal_bases(g, judgments)
    return nx.induced_subgraph(g, nbunch=node_ids)
# Run community detection on the whole graph; the sampling variant is kept
# (commented out) for quicker experimentation on a random subset.
# sg = sample_subgraph_randomly(g, k=5_000)
sg = g
print(f"{len(sg.edges)=}")
# Louvain with resolution=3 biases toward many small communities; the
# label-propagation alternative below was tried and kept for reference.
communities = list(nx.community.louvain_communities(sg, resolution=3))
# communities = list(nx.community.label_propagation_communities(sg))
print(f"{len(communities)=}")
# Log-scaled histogram of community sizes (distribution is heavy-tailed).
ax = sns.histplot([len(c) for c in communities])
ax.set(title="Community size distribution", yscale="log")
# NOTE(review): "communitiy" is a typo, but the name is kept since later
# cells outside this view may reference it.
communitiy_sizes = {idx: len(c) for idx, c in enumerate(communities)}