import itertools
import json
import polars as pl
import networkx as nx
import seaborn as sns
import pandas as pd
import random
import openai
from langchain_core.utils.json import parse_json_markdown
from functools import partial
from tqdm.auto import tqdm
"notebook") sns.set_theme(
Local subgraph analysis
# load the graph
with open("../../data/datasets/pl/graph/data/judgment_graph.json") as file:
    g_data = json.load(file)

g = nx.node_link_graph(g_data)
src_nodes, target_nodes = nx.bipartite.sets(g)

ds = pl.scan_parquet("../../data/datasets/pl/raw/*.parquet")
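A quick sanity check on what was loaded; that judgments sit in src_nodes and legal bases in target_nodes is an assumption the helpers below rely on:

# sanity check: overall graph size and the two bipartite node sets
print(f"{g.number_of_nodes()=}, {g.number_of_edges()=}")
print(f"{len(src_nodes)=}, {len(target_nodes)=}")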
Investigate local parts of the graph
Extract local graph
# set node degree as its attribute
nx.set_node_attributes(g, dict(nx.degree(g)), "degree")

deg_sorted_nodes = sorted(g.nodes(data=True), key=lambda x: x[1]["degree"], reverse=True)
def get_legal_bases_with_deg(deg: int) -> list[int]:
    return [n_id for n_id, data in deg_sorted_nodes if data["degree"] == deg and n_id in target_nodes]
def get_judgments_of_legal_base(legal_base_id: int) -> list[int]:
    dg = g.to_directed()
    src_nodes = list(dg.predecessors(legal_base_id))
    target_nodes = list(itertools.chain.from_iterable(dg.successors(n_id) for n_id in src_nodes))
    return src_nodes + target_nodes
# pick a legal base of degree 4 and collect its judgment neighborhood
LB = get_legal_bases_with_deg(4)[0]
TITLE = g.nodes[LB]["title"]
neighborhood = get_judgments_of_legal_base(LB)
print(f"Found nodes: {len(neighborhood)=}")

sg = nx.induced_subgraph(g, nbunch=neighborhood)

# collect node metadata into a frame and fetch the matching judgment texts
cases = pd.DataFrame.from_dict(dict(sg.nodes(data=True)), orient="index").reset_index().sort_values(["node_type", "date"])
case_ids = cases["_id"].dropna().tolist()

cases_text = ds.select(["_id", "text"]).filter(pl.col("_id").is_in(case_ids)).collect()
cases = cases[["index", "_id"]].merge(cases_text.to_pandas(), on="_id", how="right")
cases.head()
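Not every judgment id is guaranteed to have a matching text in the parquet dump; a small check (a sketch) surfaces silent mismatches before prompting the model:

# report judgment ids with no matching text in the parquet dataset
missing = set(case_ids) - set(cases_text["_id"].to_list())
print(f"judgments without text: {len(missing)}")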
Summarize judgments
# local OpenAI-compatible endpoint; no real API key is required
client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="sk-no-key-required",
)

# truncate each judgment text to its first 3000 characters
llm_input = "\n\n".join([t[:3000] for t in cases_text.to_dict(as_series=False)["text"]])
= """
INPUT_PROMPT You are an AI tasked with summarizing multiple Polish court judgments. Always response in English, use formal language.
First, provide an overall_summary which is a single sentence that encapsulates the common topic of all the judgments, don't be too general.
Then, for each judgment, provide a one-sentence judgment_summary, including the main reason for the decision, preserve order of judgments.
For each judgment provide keyphrases summarizing it.
Summarize followint judgments:
====
{context}
====
Format response as JSON:
``json
{{
overall_summary: string,
judgment_summaries: list of string,
keyphrases: list of lists of string,
}}
```
"""
completion = client.chat.completions.create(
    model="not-required",
    messages=[
        {"role": "user", "content": INPUT_PROMPT.format(context=llm_input)}
    ],
)
response = completion.choices[0].message.content

try:
    summary = parse_json_markdown(response)
    print(summary)
except Exception:
    print("Couldn't parse, raw response:")
    print(response)
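Before attaching the parsed output as node attributes, it is worth verifying that the model returned one summary and one keyphrase list per judgment; a minimal sketch, assuming summary parsed successfully:

# the zip() calls below truncate silently on length mismatch, so check explicitly
assert len(summary["judgment_summaries"]) == len(cases_text), "summary count mismatch"
assert len(summary["keyphrases"]) == len(cases_text), "keyphrase count mismatch"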
= {item["_id"]: item["index"] for item in cases[["index", "_id"]].to_dict("records")}
iid_2_index = {iid_2_index[iid]: text for iid, text in zip(cases_text["_id"].to_list(), summary["judgment_summaries"])}
summary_node_attr = {iid_2_index[iid]: text for iid, text in zip(cases_text["_id"].to_list(), summary["keyphrases"])}
kp_node_attr ="summary")
nx.set_node_attributes(sg, summary_node_attr, name="keyphrases") nx.set_node_attributes(sg, kp_node_attr, name
Translate legal base names
= """
TRANSLATION_PROMPT You are an AI assistant asked to translate name of Polish legal acts to Ensligh.
Provide shortest possible translations, remove dates and unimportant details.
Return only translation, without any additional output.
Example:
- Ustawa z dnia 23 kwietnia 1964 r. - Kodeks cywilny
- Civil Code (1964)
Translate this legal act name: {context}
"""
results = {}
for iid, name in tqdm(nx.get_node_attributes(sg, "title").items()):
    completion = client.chat.completions.create(
        model="not-required",
        messages=[
            {"role": "user", "content": TRANSLATION_PROMPT.format(context=name)}
        ],
    )
    results[iid] = [completion.choices[0].message.content]

nx.set_node_attributes(sg, results, "keyphrases")
Visualize
from bokeh.io import output_notebook, show
from bokeh.models import Range1d, Circle, ColumnDataSource, MultiLine, LabelSet
from bokeh.plotting import figure
from bokeh.plotting import from_networkx
from bokeh.transform import linear_cmap
output_notebook()
HOVER_TOOLTIPS = [
    ("Date", "@date"),
    ("Summary", "@summary"),
    ("ISAP", "@isap_id"),
]

COLOR_MAP = {
    "judgment": 0,
    "legal_base": 1,
}

nx.set_node_attributes(sg, {n_id: COLOR_MAP[n_data["node_type"]] for n_id, n_data in sg.nodes(data=True)}, name="nt")
color_by_this_attribute = "nt"
color_palette = ("#EA1D15", "#15E2EA")
plot = figure(
    tooltips=HOVER_TOOLTIPS,
    tools="pan,wheel_zoom,save,reset",
    active_scroll="wheel_zoom",
    x_range=Range1d(-10.1, 10.1),
    y_range=Range1d(-10.1, 10.1),
    width=1_200,
    height=600,
)
plot.xgrid.visible = False
plot.ygrid.visible = False
plot.xaxis.visible = False
plot.yaxis.visible = False

n_ids = [n_id for n_id in sg.nodes if sg.nodes[n_id]["node_type"] == "judgment"]
n_ids_2 = [n_id for n_id in sg.nodes if sg.nodes[n_id]["node_type"] == "legal_base"]
network_graph = from_networkx(sg, partial(nx.bipartite_layout, nodes=n_ids), scale=10, center=(0, 0))
# color nodes by node type via a linear color map over the "nt" attribute
minimum_value_color = min(network_graph.node_renderer.data_source.data[color_by_this_attribute])
maximum_value_color = max(network_graph.node_renderer.data_source.data[color_by_this_attribute])
network_graph.node_renderer.glyph = Circle(radius=0.30, fill_color=linear_cmap(color_by_this_attribute, color_palette, minimum_value_color, maximum_value_color))

# set edge opacity and width
network_graph.edge_renderer.glyph = MultiLine(line_alpha=0.5, line_width=1)
# index the layout by node id so each label lands on its own node
layout = network_graph.layout_provider.graph_layout
node_ids = n_ids + n_ids_2
x = [layout[n_id][0] for n_id in node_ids]
y = [layout[n_id][1] for n_id in node_ids]
node_labels = [",".join(sg.nodes[n_id]["keyphrases"]) for n_id in node_ids]
source = ColumnDataSource({"x": x, "y": y, "name": node_labels})
labels = LabelSet(x="x", y="y", text="name", source=source, background_fill_color="white", text_font_size="14px", background_fill_alpha=1.0)

plot.renderers.append(labels)
plot.renderers.append(network_graph)
show(plot)
Community detection
def connected_legal_bases(g: nx.Graph, nbunch: list):
    nbunch = set(nbunch)
    return list(set(target_id for src_id, target_id in g.edges if src_id in nbunch))


def sample_subgraph_randomly(g: nx.Graph, k: int) -> nx.Graph:
    sampled_nodes = random.sample(list(src_nodes), k=k)
    subgraph_node_ids = sampled_nodes + connected_legal_bases(g, sampled_nodes)
    return nx.induced_subgraph(g, nbunch=subgraph_node_ids)


# sg = sample_subgraph_randomly(g, k=5_000)
sg = g
print(f"{len(sg.edges)=}")
communities = list(nx.community.louvain_communities(sg, resolution=3))
# communities = list(nx.community.label_propagation_communities(sg))
print(f"{len(communities)=}")

ax = sns.histplot([len(c) for c in communities])
ax.set(title="Community size distribution", yscale="log")

community_sizes = {idx: len(c) for idx, c in enumerate(communities)}
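To get a feel for what the detected communities contain, one can inspect the largest one; a minimal sketch, assuming (as the sampling helper above does) that src_nodes holds the judgment side of the bipartite split:

# break the largest community down into judgments vs. legal bases
largest_idx = max(community_sizes, key=community_sizes.get)
largest = communities[largest_idx]
n_judgments = sum(1 for n in largest if n in src_nodes)
print(f"community {largest_idx}: {len(largest)} nodes, {n_judgments} judgments, {len(largest) - n_judgments} legal bases")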