"""Axis registry and character state snapshot repository operations."""
from __future__ import annotations
import json
import sqlite3
from datetime import UTC, datetime
from secrets import randbelow
from typing import Any, NoReturn, cast
from pipeworks_ipc import compute_payload_hash
from mud_server.db.connection import connection_scope
from mud_server.db.constants import DEFAULT_AXIS_SCORE
from mud_server.db.errors import (
DatabaseError,
DatabaseOperationContext,
DatabaseReadError,
DatabaseWriteError,
)
from mud_server.db.types import AxisRegistrySeedStats
def _raise_read_error(operation: str, exc: Exception, *, details: str | None = None) -> NoReturn:
    """Re-raise ``exc`` as a typed repository read failure.

    An exception that already derives from ``DatabaseError`` is re-raised
    unchanged. Anything else is wrapped in ``DatabaseReadError`` carrying an
    operation context, with ``exc`` kept as the chained cause.

    Args:
        operation: Name of the repository operation that failed.
        exc: The exception caught by the caller.
        details: Optional free-form detail text stored on the context.

    Raises:
        DatabaseError: Always; this function never returns.
    """
    if not isinstance(exc, DatabaseError):
        failure_context = DatabaseOperationContext(operation=operation, details=details)
        raise DatabaseReadError(context=failure_context, cause=exc) from exc
    # Already typed by a lower layer: propagate without re-wrapping.
    raise exc
def _raise_write_error(operation: str, exc: Exception, *, details: str | None = None) -> NoReturn:
    """Re-raise ``exc`` as a typed repository write failure.

    An exception that already derives from ``DatabaseError`` is re-raised
    unchanged. Anything else is wrapped in ``DatabaseWriteError`` carrying an
    operation context, with ``exc`` kept as the chained cause.

    Args:
        operation: Name of the repository operation that failed.
        exc: The exception caught by the caller.
        details: Optional free-form detail text stored on the context.

    Raises:
        DatabaseError: Always; this function never returns.
    """
    if not isinstance(exc, DatabaseError):
        failure_context = DatabaseOperationContext(operation=operation, details=details)
        raise DatabaseWriteError(context=failure_context, cause=exc) from exc
    # Already typed by a lower layer: propagate without re-wrapping.
    raise exc
def _generate_state_seed() -> int:
    """Draw a random character-state seed in ``1..2_147_483_647`` (never zero)."""
    upper_bound = 2_147_483_647
    # randbelow() yields 0..upper_bound-1; shifting by one excludes zero.
    return 1 + randbelow(upper_bound)
def _extract_axis_ordering_values(axis_data: dict[str, Any]) -> list[str]:
    """Return the ``ordering.values`` entries of an axis definition as strings.

    An empty list is returned when ``axis_data`` is empty/``None``, when
    ``ordering`` is not a dict, or when ``ordering["values"]`` is not a list.
    """
    ordering_section = axis_data.get("ordering") if axis_data else None
    raw_values = ordering_section.get("values") if isinstance(ordering_section, dict) else None
    if isinstance(raw_values, list):
        return list(map(str, raw_values))
    return []
# Public API.
def seed_axis_registry(
    *,
    world_id: str,
    axes_payload: dict[str, Any],
    thresholds_payload: dict[str, Any],
) -> AxisRegistrySeedStats:
    """Upsert a world's axes and rebuild their threshold value rows.

    Each entry under ``axes_payload["axes"]`` is upserted into ``axis``.
    When ``thresholds_payload["axes"]`` holds a dict for the same axis name,
    that axis's ``axis_value`` rows are deleted and re-inserted from the
    dict's ``values`` mapping. Ordinals are the positions of each value name
    in the axis ``ordering`` values list, or ``NULL`` when not listed there.

    Args:
        world_id: World whose axis registry is being seeded.
        axes_payload: Policy payload; axis definitions are read from ``"axes"``.
        thresholds_payload: Policy payload; per-axis thresholds are read from
            ``"axes"``.

    Returns:
        Counters for axes upserted, axis values inserted, axes without a
        thresholds dict, and axes skipped (no row found after the upsert, or
        a non-dict ``values`` entry).

    Raises:
        DatabaseError: Any failure inside the write scope is routed through
            ``_raise_write_error``.
    """
    axis_specs = axes_payload.get("axes") or {}
    threshold_specs = thresholds_payload.get("axes") or {}
    tally = {
        "axes_upserted": 0,
        "axis_values_inserted": 0,
        "axes_missing_thresholds": 0,
        "axis_values_skipped": 0,
    }

    try:
        with connection_scope(write=True) as conn:
            cur = conn.cursor()
            for axis_name, raw_axis in axis_specs.items():
                axis_spec = raw_axis or {}
                ordering = axis_spec.get("ordering")
                ordering_json = None
                if ordering:
                    ordering_json = json.dumps(ordering, sort_keys=True)

                cur.execute(
                    """
                    INSERT INTO axis (world_id, name, description, ordering_json)
                    VALUES (?, ?, ?, ?)
                    ON CONFLICT(world_id, name) DO UPDATE SET
                    description = excluded.description,
                    ordering_json = excluded.ordering_json
                    """,
                    (world_id, axis_name, axis_spec.get("description"), ordering_json),
                )
                tally["axes_upserted"] += 1

                # Re-read the id: it is needed whether the upsert inserted or updated.
                cur.execute(
                    "SELECT id FROM axis WHERE world_id = ? AND name = ? LIMIT 1",
                    (world_id, axis_name),
                )
                found = cur.fetchone()
                if not found:
                    tally["axis_values_skipped"] += 1
                    continue
                axis_id = int(found[0])

                axis_thresholds = threshold_specs.get(axis_name)
                if not isinstance(axis_thresholds, dict):
                    tally["axes_missing_thresholds"] += 1
                    continue

                value_specs = axis_thresholds.get("values") or {}
                if not isinstance(value_specs, dict):
                    tally["axis_values_skipped"] += 1
                    continue

                ordinal_by_label = {
                    label: position
                    for position, label in enumerate(_extract_axis_ordering_values(axis_spec))
                }

                # Replace, rather than merge, the axis's existing value rows.
                cur.execute("DELETE FROM axis_value WHERE axis_id = ?", (axis_id,))
                for value_name, raw_bounds in value_specs.items():
                    bounds = raw_bounds or {}
                    lower = bounds.get("min")
                    upper = bounds.get("max")
                    lower = None if lower is None else float(lower)
                    upper = None if upper is None else float(upper)
                    label = str(value_name)
                    cur.execute(
                        """
                        INSERT INTO axis_value (axis_id, value, min_score, max_score, ordinal)
                        VALUES (?, ?, ?, ?, ?)
                        """,
                        (axis_id, label, lower, upper, ordinal_by_label.get(label)),
                    )
                    tally["axis_values_inserted"] += 1
    except Exception as exc:
        _raise_write_error(
            "axis.seed_axis_registry",
            exc,
            details=f"world_id={world_id!r}",
        )

    return AxisRegistrySeedStats(**tally)
def _get_axis_policy_hash(world_id: str) -> str | None:
    """Return the canonical manifest+axis policy hash for one world.

    This helper is runtime-facing and intentionally DB-first. It resolves
    world-scope activation pointers for:

    1. ``manifest_bundle:world.manifests:<world_id>``
    2. ``axis_bundle:axis.bundles:<bundle_id>`` selected by the manifest

    and hashes those canonical payloads together.

    Args:
        world_id: World whose active manifest and axis bundle are hashed.

    Returns:
        The payload hash as a string, or ``None`` when any link in the chain
        is missing or malformed: no activation, no policy row, non-dict
        content, non-dict ``manifest``/``axis``/``active_bundle`` sections,
        or a blank bundle id.
    """
    from mud_server.db import policy_repo

    world_activations = {
        str(row["policy_id"]): row
        for row in policy_repo.list_policy_activations(world_id=world_id, client_profile="")
    }

    def _active_content(policy_id: str) -> dict[str, Any] | None:
        """Return the dict content of the activated variant of ``policy_id``."""
        activation = world_activations.get(policy_id)
        if activation is None:
            return None
        policy_row = policy_repo.get_policy(
            policy_id=policy_id,
            variant=str(activation["variant"]),
        )
        if policy_row is None:
            return None
        content = policy_row.get("content")
        return content if isinstance(content, dict) else None

    manifest_content = _active_content(f"manifest_bundle:world.manifests:{world_id}")
    if manifest_content is None:
        return None
    manifest_payload = manifest_content.get("manifest")
    if not isinstance(manifest_payload, dict):
        return None

    # A malformed ``axis`` section (e.g. a string) is treated like any other
    # malformed manifest shape instead of raising AttributeError on ``.get``.
    axis_section = manifest_payload.get("axis")
    if not isinstance(axis_section, dict):
        return None
    axis_active_bundle = axis_section.get("active_bundle")
    if not isinstance(axis_active_bundle, dict):
        return None
    axis_bundle_id = str(axis_active_bundle.get("id", "")).strip()
    if not axis_bundle_id:
        return None

    axis_content = _active_content(f"axis_bundle:axis.bundles:{axis_bundle_id}")
    if axis_content is None:
        return None

    return str(
        compute_payload_hash(
            {
                "manifest": manifest_payload,
                "axis_bundle": axis_content,
            }
        )
    )
def _resolve_axis_label_for_score(cursor: sqlite3.Cursor, axis_id: int, score: float) -> str | None:
"""Resolve axis score to a label using axis_value thresholds."""
cursor.execute(
"""
SELECT value
FROM axis_value
WHERE axis_id = ?
AND (? >= min_score OR min_score IS NULL)
AND (? <= max_score OR max_score IS NULL)
ORDER BY
CASE WHEN ordinal IS NULL THEN 1 ELSE 0 END,
ordinal,
min_score
LIMIT 1
""",
(axis_id, score, score),
)
row = cursor.fetchone()
return row[0] if row else None
def _resolve_axis_score_for_label(
cursor: sqlite3.Cursor,
*,
world_id: str,
axis_name: str,
axis_label: str,
) -> float | None:
"""Resolve axis label to a representative numeric score."""
cursor.execute(
"""
SELECT av.min_score, av.max_score
FROM axis_value av
JOIN axis a ON a.id = av.axis_id
WHERE a.world_id = ? AND a.name = ? AND av.value = ?
LIMIT 1
""",
(world_id, axis_name, axis_label),
)
row = cursor.fetchone()
if not row:
return None
min_score = float(row[0]) if row[0] is not None else None
max_score = float(row[1]) if row[1] is not None else None
if min_score is not None and max_score is not None:
return (min_score + max_score) / 2.0
if min_score is not None:
return min_score
if max_score is not None:
return max_score
return DEFAULT_AXIS_SCORE
def _flatten_entity_axis_labels(entity_state: dict[str, Any]) -> dict[str, str]:
"""Flatten entity payload labels into ``axis_name -> label`` mapping."""
labels: dict[str, str] = {}
for group in ("character", "occupation"):
group_payload = entity_state.get(group)
if isinstance(group_payload, dict):
for axis_name, axis_value in group_payload.items():
if isinstance(axis_value, str) and axis_value.strip():
labels[str(axis_name)] = axis_value.strip()
axes_payload = entity_state.get("axes")
if isinstance(axes_payload, dict):
for axis_name, axis_value in axes_payload.items():
if isinstance(axis_value, dict):
label = axis_value.get("label")
if isinstance(label, str) and label.strip():
labels[str(axis_name)] = label.strip()
elif isinstance(axis_value, str) and axis_value.strip():
labels[str(axis_name)] = axis_value.strip()
return labels
# Public API.
def apply_entity_state_to_character(
    *,
    character_id: int,
    world_id: str,
    entity_state: dict[str, Any],
    seed: int | None = None,
    event_type_name: str = "entity_profile_seeded",
) -> int | None:
    """Move a character's axis scores toward entity-state labels via one event.

    Each label found in ``entity_state`` is resolved to a target score; the
    difference from the character's current score (``DEFAULT_AXIS_SCORE``
    when the character has no row for that axis) becomes a delta. Labels that
    cannot be resolved and deltas smaller than ``1e-9`` are dropped. The
    remaining deltas are submitted through ``apply_axis_event``.

    Args:
        character_id: Character receiving the deltas.
        world_id: World the axes and scores belong to.
        entity_state: Payload carrying axis labels.
        seed: Optional seed recorded in the event metadata.
        event_type_name: Event type name passed to ``apply_axis_event``.

    Returns:
        Whatever ``apply_axis_event`` returns, or ``None`` when there are no
        labels or no non-trivial deltas.

    Raises:
        DatabaseError: Failures while computing deltas are routed through
            ``_raise_read_error``.
    """
    from mud_server.db.events_repo import apply_axis_event

    labels_by_axis = _flatten_entity_axis_labels(entity_state)
    if not labels_by_axis:
        return None

    score_deltas: dict[str, float] = {}
    try:
        with connection_scope() as conn:
            cur = conn.cursor()
            existing_scores: dict[str, float] = {}
            for score_row in _fetch_character_axis_scores(cur, character_id, world_id):
                existing_scores[score_row["axis_name"]] = float(score_row["axis_score"])

            for name, label in labels_by_axis.items():
                target = _resolve_axis_score_for_label(
                    cur,
                    world_id=world_id,
                    axis_name=name,
                    axis_label=label,
                )
                if target is None:
                    continue
                shift = target - existing_scores.get(name, DEFAULT_AXIS_SCORE)
                # Ignore float noise so no-op deltas are not recorded.
                if abs(shift) < 1e-9:
                    continue
                score_deltas[name] = shift
    except Exception as exc:
        _raise_read_error(
            "axis.apply_entity_state_to_character",
            exc,
            details=f"character_id={character_id}, world_id={world_id!r}",
        )

    if not score_deltas:
        return None

    event_metadata: dict[str, str] = {
        "source": "entity_state_api",
        "axis_count": str(len(score_deltas)),
    }
    if seed is not None:
        event_metadata["seed"] = str(seed)

    return apply_axis_event(
        world_id=world_id,
        character_id=character_id,
        event_type_name=event_type_name,
        event_type_description=(
            "Initial axis profile generated from external entity-state integration."
        ),
        deltas=score_deltas,
        metadata=event_metadata,
    )
def _fetch_character_axis_scores(
    cursor: sqlite3.Cursor,
    character_id: int,
    world_id: str,
) -> list[dict[str, Any]]:
    """List a character's axis scores in one world, ordered by axis name.

    Returns:
        One dict per score row with keys ``axis_id`` (int), ``axis_name`` and
        ``axis_score`` (float).
    """
    fetched = cursor.execute(
        """
        SELECT a.id, a.name, s.axis_score
        FROM character_axis_score s
        JOIN axis a ON a.id = s.axis_id
        WHERE s.character_id = ? AND s.world_id = ?
        ORDER BY a.name
        """,
        (character_id, world_id),
    ).fetchall()

    score_rows: list[dict[str, Any]] = []
    for axis_id, axis_name, axis_score in fetched:
        score_rows.append(
            {
                "axis_id": int(axis_id),
                "axis_name": axis_name,
                "axis_score": float(axis_score),
            }
        )
    return score_rows
def _seed_character_axis_scores(
    cursor: sqlite3.Cursor,
    *,
    character_id: int,
    world_id: str,
    default_score: float = DEFAULT_AXIS_SCORE,
) -> None:
    """Insert a default score row for every axis of the world.

    Uses ``INSERT OR IGNORE``, so rows that conflict with an existing one
    are left untouched.

    Args:
        cursor: Cursor on the connection to write through.
        character_id: Character to seed.
        world_id: World whose axes are enumerated.
        default_score: Score stored for each newly inserted row.
    """
    world_axes = cursor.execute(
        """
        SELECT id, name
        FROM axis
        WHERE world_id = ?
        ORDER BY name
        """,
        (world_id,),
    ).fetchall()

    # Parameters are produced lazily, one axis at a time, in name order.
    cursor.executemany(
        """
        INSERT OR IGNORE INTO character_axis_score
        (character_id, world_id, axis_id, axis_score)
        VALUES (?, ?, ?, ?)
        """,
        (
            (character_id, world_id, int(axis_id), float(default_score))
            for axis_id, _axis_name in world_axes
        ),
    )
def _build_character_state_snapshot(
    cursor: sqlite3.Cursor,
    *,
    character_id: int,
    world_id: str,
    seed: int,
    policy_hash: str | None,
) -> dict[str, Any]:
    """Assemble the snapshot payload for a character's current axis scores.

    Returns:
        A dict with ``world_id``, ``seed``, ``policy_hash`` and ``axes``,
        where ``axes`` maps each axis name to its ``score`` and the ``label``
        resolved for that score (``None`` when no threshold matches).
    """
    score_rows = _fetch_character_axis_scores(cursor, character_id, world_id)
    axes_by_name: dict[str, Any] = {
        score_row["axis_name"]: {
            "score": score_row["axis_score"],
            "label": _resolve_axis_label_for_score(
                cursor, score_row["axis_id"], score_row["axis_score"]
            ),
        }
        for score_row in score_rows
    }
    return {
        "world_id": world_id,
        "seed": seed,
        "policy_hash": policy_hash,
        "axes": axes_by_name,
    }
def _seed_character_state_snapshot(
    cursor: sqlite3.Cursor,
    *,
    character_id: int,
    world_id: str,
    seed: int | None = None,
) -> None:
    """Write the initial base/current state snapshots for a character.

    The snapshot seed is the character's stored ``state_seed`` when that is
    positive, otherwise ``seed`` when given, otherwise a freshly generated
    one. The update keeps an existing ``base_state_json``, always overwrites
    ``current_state_json``, only fills ``state_seed`` when it is ``NULL`` or
    0, and stores the axis policy hash in ``state_version``.
    """
    stored = cursor.execute(
        "SELECT state_seed FROM characters WHERE id = ?", (character_id,)
    ).fetchone()
    stored_seed = 0
    if stored and stored[0] is not None:
        stored_seed = int(stored[0])

    if stored_seed > 0:
        effective_seed = stored_seed
    else:
        effective_seed = seed if seed is not None else _generate_state_seed()

    policy_hash = _get_axis_policy_hash(world_id)
    snapshot_json = json.dumps(
        _build_character_state_snapshot(
            cursor,
            character_id=character_id,
            world_id=world_id,
            seed=effective_seed,
            policy_hash=policy_hash,
        ),
        sort_keys=True,
    )
    updated_at = datetime.now(UTC).isoformat()

    # Positional order: base snapshot, current snapshot, seed, version, timestamp, id.
    update_params = (
        snapshot_json,
        snapshot_json,
        effective_seed,
        policy_hash,
        updated_at,
        character_id,
    )
    cursor.execute(
        """
        UPDATE characters
        SET base_state_json = COALESCE(base_state_json, ?),
        current_state_json = ?,
        state_seed = CASE
        WHEN state_seed IS NULL OR state_seed = 0 THEN ?
        ELSE state_seed
        END,
        state_version = ?,
        state_updated_at = ?
        WHERE id = ?
        """,
        update_params,
    )
def _refresh_character_current_snapshot(
    cursor: sqlite3.Cursor,
    *,
    character_id: int,
    world_id: str,
    seed_increment: int = 1,
) -> None:
    """Rebuild a character's current snapshot after its axis scores changed.

    The stored ``state_seed`` (0 when missing or ``NULL``) is advanced by
    ``seed_increment``. The new seed, a freshly built snapshot, the axis
    policy hash (as ``state_version``) and a UTC timestamp are written back.
    ``base_state_json`` is not touched.
    """
    stored = cursor.execute(
        "SELECT state_seed FROM characters WHERE id = ?", (character_id,)
    ).fetchone()
    previous_seed = 0
    if stored and stored[0] is not None:
        previous_seed = int(stored[0])
    next_seed = previous_seed + seed_increment

    policy_hash = _get_axis_policy_hash(world_id)
    snapshot_json = json.dumps(
        _build_character_state_snapshot(
            cursor,
            character_id=character_id,
            world_id=world_id,
            seed=next_seed,
            policy_hash=policy_hash,
        ),
        sort_keys=True,
    )
    updated_at = datetime.now(UTC).isoformat()

    # Positional order: current snapshot, seed, version, timestamp, id.
    update_params = (snapshot_json, next_seed, policy_hash, updated_at, character_id)
    cursor.execute(
        """
        UPDATE characters
        SET current_state_json = ?,
        state_seed = ?,
        state_version = ?,
        state_updated_at = ?
        WHERE id = ?
        """,
        update_params,
    )
# Public API.
def get_character_axis_state(character_id: int) -> dict[str, Any] | None:
    """Return axis scores plus stored snapshot payloads for one character.

    Args:
        character_id: Primary key of the character to inspect.

    Returns:
        ``None`` when the character does not exist. Otherwise a dict with
        ``character_id``, ``world_id``, ``state_seed``, ``state_version``,
        ``state_updated_at``, the decoded ``base_state`` / ``current_state``
        snapshots, and ``axes``: one entry per score row carrying
        ``axis_id``, ``axis_name``, ``axis_score`` and the resolved
        ``axis_label``. A snapshot is ``None`` when its column is empty, is
        not valid JSON, or does not decode to a JSON object.

    Raises:
        DatabaseError: Failures are routed through ``_raise_read_error``.
    """

    def _safe_load(payload: str | None) -> dict[str, Any] | None:
        """Decode a stored snapshot, tolerating missing or corrupt values."""
        if not payload:
            return None
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError:
            return None
        # json.loads accepts any JSON value; only an object is a snapshot.
        return decoded if isinstance(decoded, dict) else None

    try:
        with connection_scope() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id,
                world_id,
                base_state_json,
                current_state_json,
                state_seed,
                state_version,
                state_updated_at
                FROM characters
                WHERE id = ?
                """,
                (character_id,),
            )
            row = cursor.fetchone()
            if not row:
                return None

            world_id = row[1]
            base_state_json = row[2]
            current_state_json = row[3]
            state_seed = row[4]
            state_version = row[5]
            state_updated_at = row[6]

            axes = [
                {
                    "axis_id": axis_row["axis_id"],
                    "axis_name": axis_row["axis_name"],
                    "axis_score": axis_row["axis_score"],
                    "axis_label": _resolve_axis_label_for_score(
                        cursor, axis_row["axis_id"], axis_row["axis_score"]
                    ),
                }
                for axis_row in _fetch_character_axis_scores(cursor, character_id, world_id)
            ]

            return {
                "character_id": int(row[0]),
                "world_id": world_id,
                "state_seed": state_seed,
                "state_version": state_version,
                "state_updated_at": state_updated_at,
                "base_state": _safe_load(base_state_json),
                "current_state": _safe_load(current_state_json),
                "axes": axes,
            }
    except Exception as exc:
        _raise_read_error(
            "axis.get_character_axis_state",
            exc,
            details=f"character_id={character_id}",
        )