Source code for calute.extensions.hooks

# Copyright 2025 The EasyDeL/Calute Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Hook system for Calute lifecycle events.

Hooks allow external code to observe and mutate agent execution at
well-defined points. Each hook point has documented semantics:

Hook Points:
    - ``before_tool_call(tool_name, arguments, agent_id) -> arguments``
        Called before a tool is executed. Can modify arguments.
        Return modified arguments dict or None to keep original.

    - ``after_tool_call(tool_name, arguments, result, agent_id) -> result``
        Called after a tool executes. Can transform the result.
        Return modified result or None to keep original.

    - ``tool_result_persist(tool_name, result, agent_id) -> result``
        Called before tool result is stored in conversation history.
        Allows sanitizing/transforming results for persistence.

    - ``bootstrap_files(agent_id) -> list[str]``
        Called during prompt assembly. Returns extra content strings
        to inject into the system prompt.

    - ``on_turn_start(agent_id, messages)``
        Called when a new agent turn begins.

    - ``on_turn_end(agent_id, response)``
        Called when an agent turn completes.

    - ``on_error(agent_id, error)``
        Called when an error occurs during execution.

Execution semantics:
    - Hooks run in registration order.
    - A hook that raises is logged and skipped (does not break the chain).
    - Mutation hooks (before_tool_call, after_tool_call, tool_result_persist)
      pass their return value to the next hook in the chain.
"""

from __future__ import annotations

import logging
import typing as tp

logger = logging.getLogger(__name__)

HookCallback = tp.Callable[..., tp.Any]

HOOK_POINTS = frozenset(
    {
        "before_tool_call",
        "after_tool_call",
        "tool_result_persist",
        "bootstrap_files",
        "on_turn_start",
        "on_turn_end",
        "on_error",
    }
)

_MUTATION_HOOKS = frozenset({"before_tool_call", "after_tool_call", "tool_result_persist"})


[docs]class HookRunner: """Manages registration and execution of lifecycle hooks. ``HookRunner`` maintains an ordered list of callback functions for each defined hook point. Hooks are divided into two categories: * **Mutation hooks** (``before_tool_call``, ``after_tool_call``, ``tool_result_persist``) -- callbacks are chained so each can modify a value that is passed to the next callback. * **Observation hooks** (``bootstrap_files``, ``on_turn_start``, ``on_turn_end``, ``on_error``) -- all callbacks are invoked and their non-``None`` return values are collected into a list. Callbacks that raise exceptions are logged and skipped without breaking the hook chain. Attributes: _hooks: Internal mapping from hook point names to ordered lists of registered callbacks. Example: >>> runner = HookRunner() >>> runner.register("before_tool_call", my_hook_fn) >>> modified_args = runner.run("before_tool_call", ... tool_name="search", arguments={"q": "hello"}, agent_id="a1") """ def __init__(self) -> None: """Initialize the HookRunner with empty callback lists for all hook points. Pre-populates the internal ``_hooks`` dictionary with an empty list for every name defined in :data:`HOOK_POINTS`. """ self._hooks: dict[str, list[HookCallback]] = {name: [] for name in HOOK_POINTS}
[docs] def register(self, hook_point: str, callback: HookCallback) -> None: """Register a hook callback for a specific hook point. Args: hook_point: One of the defined HOOK_POINTS. callback: The callable to invoke at this hook point. Raises: ValueError: If hook_point is not recognized. """ if hook_point not in HOOK_POINTS: raise ValueError(f"Unknown hook point '{hook_point}'. Valid: {sorted(HOOK_POINTS)}") self._hooks[hook_point].append(callback) logger.debug( "Registered hook for '%s': %s", hook_point, callback.__name__ if hasattr(callback, "__name__") else str(callback), )
[docs] def unregister(self, hook_point: str, callback: HookCallback) -> bool: """Remove a previously registered callback from a hook point. Args: hook_point: The hook point name to search in. callback: The exact callable instance to remove (identity match). Returns: ``True`` if the callback was found and removed, ``False`` if the hook point does not exist or the callback was not registered. """ if hook_point not in self._hooks: return False try: self._hooks[hook_point].remove(callback) return True except ValueError: return False
[docs] def clear(self, hook_point: str | None = None) -> None: """Clear all registered callbacks for one or all hook points. Args: hook_point: If provided, only callbacks for this specific hook point are removed. If ``None`` (the default), callbacks for **every** hook point are removed. """ if hook_point: self._hooks[hook_point] = [] else: self._hooks = {name: [] for name in HOOK_POINTS}
[docs] def run(self, hook_point: str, **kwargs) -> tp.Any: """Execute all hooks registered for a hook point. Dispatches to :meth:`_run_mutation` or :meth:`_run_observation` depending on the hook category. For mutation hooks (``before_tool_call``, ``after_tool_call``, ``tool_result_persist``): - The return value from each hook is passed as updated kwargs to the next hook in the chain. - For ``before_tool_call``: the return value replaces ``'arguments'``. - For ``after_tool_call`` / ``tool_result_persist``: the return value replaces ``'result'``. - Returns the final mutated value. For observation hooks (``bootstrap_files``, ``on_turn_start``, etc.): - All callbacks are invoked; return values are collected. - Returns a list of non-``None`` return values. If no callbacks are registered the method returns the relevant default value (``arguments`` or ``result`` from *kwargs*). Args: hook_point: The hook point name to execute. **kwargs: Keyword arguments forwarded to every callback. The exact keys depend on the hook point (see module docstring). Returns: The final mutated value for mutation hooks, or a list of collected results for observation hooks. """ callbacks = self._hooks.get(hook_point, []) if not callbacks: return kwargs.get("arguments") if hook_point == "before_tool_call" else kwargs.get("result") if hook_point in _MUTATION_HOOKS: return self._run_mutation(hook_point, callbacks, **kwargs) else: return self._run_observation(hook_point, callbacks, **kwargs)
def _run_mutation(self, hook_point: str, callbacks: list[HookCallback], **kwargs) -> tp.Any: """Run mutation hooks, chaining non-``None`` return values. Each callback receives the current *kwargs*. When a callback returns a non-``None`` value the corresponding mutable key (``"arguments"`` for ``before_tool_call``, ``"result"`` for others) is updated in *kwargs* before the next callback is invoked. Args: hook_point: The mutation hook point name. callbacks: Ordered list of callbacks to invoke. **kwargs: The keyword arguments passed through the chain. Returns: The final value of the mutated key after all callbacks have run. """ if hook_point == "before_tool_call": mutated_key = "arguments" else: mutated_key = "result" current = kwargs.get(mutated_key) for cb in callbacks: try: ret = cb(**kwargs) if ret is not None: current = ret kwargs[mutated_key] = current except Exception: logger.warning("Hook '%s' raised in %s", hook_point, cb, exc_info=True) return current def _run_observation(self, hook_point: str, callbacks: list[HookCallback], **kwargs) -> list[tp.Any]: """Run observation hooks, collecting non-``None`` return values. All callbacks are invoked regardless of individual return values. Exceptions are logged and do not prevent subsequent callbacks from executing. Args: hook_point: The observation hook point name. callbacks: Ordered list of callbacks to invoke. **kwargs: The keyword arguments forwarded to every callback. Returns: A list of non-``None`` values returned by the callbacks. """ results = [] for cb in callbacks: try: ret = cb(**kwargs) if ret is not None: results.append(ret) except Exception: logger.warning("Hook '%s' raised in %s", hook_point, cb, exc_info=True) return results
[docs] def has_hooks(self, hook_point: str) -> bool: """Check whether any callbacks are registered for a hook point. Args: hook_point: The hook point name to query. Returns: ``True`` if at least one callback is registered, ``False`` otherwise (including when *hook_point* is not a recognized name). """ return bool(self._hooks.get(hook_point))