# Source code for indra_network_search.result_handler

"""Handles the aggregation of results from the IndraNetworkSearchAPI

The result manager deals with things like:

    - Stopping path iteration when timeout is reached
    - Keeping count of number of paths returned
    - Filtering results when it's not done in the algorithm

"""
import logging
from datetime import datetime, timedelta
from itertools import product
from typing import (
    Any,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

from indra.databases.identifiers import get_identifiers_url
from indra.explanation.pathfinding import (
    bfs_search,
    open_dijkstra_search,
    shortest_simple_paths,
)
from networkx import DiGraph, NetworkXNoPath
from pydantic import BaseModel, ValidationError

from indra_network_search.data_models import *
from indra_network_search.pathfinding import *
from indra_network_search.rest_util import StrNode

# Public API of this module
__all__ = [
    "ResultManager",
    "DijkstraResultManager",
    "ShortestSimplePathsResultManager",
    "BreadthFirstSearchResultManager",
    "SharedInteractorsResultManager",
    "OntologyResultManager",
    "SubgraphResultManager",
    "MultiInteractorsResultManager",
    "alg_manager_mapping",
    "DB_URL_HASH",
    "DB_URL_EDGE",
]


logger = logging.getLogger(__name__)


# URL template for looking up a single statement by its hash in the INDRA DB
DB_URL_HASH = "https://db.indra.bio/statements/from_hash/{stmt_hash}?format=html"
# URL template for looking up all statements for a (subject, object) pair in
# the INDRA DB
DB_URL_EDGE = (
    "https://db.indra.bio/statements/from_agents?subject="
    "{subj_id}@{subj_ns}&object={obj_id}@"
    "{obj_ns}&ev_limit={ev_limit}&format=html"
)


class ResultManager:
    """Base class for handling results from the search algorithms

    Loops the path generator produced by a search algorithm, applies any
    filters that were not already applied inside the algorithm, and
    assembles the data into result models.
    """

    # Todo: this class is just a parent class for results, we might also
    #  need a wrapper class that manages all the results, analogous to
    #  query vs query_handler

    # Name of the search algorithm whose results this manager handles
    alg_name: str = NotImplemented
    # Whether node filters should also be applied to the input nodes
    filter_input_node: bool = NotImplemented

    def __init__(
        self,
        path_generator: Union[Generator, Iterator, Iterable],
        graph: DiGraph,
        filter_options: FilterOptions,
        input_nodes: List[Union[StrNode, Node]],
        timeout: Optional[float] = DEFAULT_TIMEOUT,
    ):
        """Initialize the result manager

        Parameters
        ----------
        path_generator :
            The generator/iterator of raw results from the search algorithm
        graph :
            The graph the search was run in
        filter_options :
            Filter options from the query; filters already applied inside
            the algorithm are removed via ``_remove_used_filters``
        input_nodes :
            The node(s) that were input to the search
        timeout :
            Max time (in seconds) allowed for assembling results
        """
        self.path_gen: Union[Generator, Iterator, Iterable] = path_generator
        self.start_time: Optional[datetime] = None  # Start when looping paths
        self.timeout = timeout
        self.timed_out = False
        # Remove used filters per algorithm
        self.filter_options: FilterOptions = self._remove_used_filters(filter_options)
        self._graph: DiGraph = graph
        self.input_nodes: List[Union[StrNode, Node]] = input_nodes
        # Set for access in this class, only used in UIResultManager
        self._hash_blacklist: Set[int] = set()

    def _pass_node(self, node: Node) -> bool:
        """Pass an individual node based on node data"""
        raise NotImplementedError

    def _pass_stmt(self, stmt_dict: Dict[str, Union[str, int, float, Dict[str, int]]]) -> bool:
        """Pass an individual statement based statement dict content"""
        # Check:
        # - stmt_type
        # - hash_blacklist
        # - belief
        # - curated db
        # Order the checks by likelihood of being applied

        # Skip checking fplx edges as they don't have int hashes
        if (
            self._hash_blacklist
            and stmt_dict["stmt_type"].lower() != "fplx"
            and int(stmt_dict["stmt_hash"]) in self._hash_blacklist
        ):
            return False

        # Filter out statement types not in the allowed set, if one is given
        if self.filter_options.stmt_filter and stmt_dict["stmt_type"].lower() not in self.filter_options.stmt_filter:
            return False

        # Filter out statements below the belief cutoff, if one is set
        if self.filter_options.belief_cutoff > 0.0 and self.filter_options.belief_cutoff > stmt_dict["belief"]:
            return False

        # Only allow statements from curated databases, if requested
        if self.filter_options.curated_db_only and not stmt_dict["curated"]:
            return False

        return True

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        """Remove filters already applied in algorithm"""
        raise NotImplementedError

    def _get_node(self, node_name: StrNode, apply_filter: bool = True) -> Optional[Node]:
        """Look up a node in the graph and return it as a Node model

        Returns None if the node is missing ns/id data in the graph or if
        it is filtered out (when apply_filter is True).
        """
        # Check if node is signed
        if isinstance(node_name, tuple):
            name, sign = node_name
            node_info = {"name": name, "sign": sign}
        else:
            name, sign = node_name, None
            node_info = {"name": name}

        # Check if node exists in graph
        db_ns = self._graph.nodes.get(node_name, {}).get("ns")
        db_id = self._graph.nodes.get(node_name, {}).get("id")
        if db_id is None or db_ns is None:
            return None

        # Add ns/id to data
        node_info["namespace"] = db_ns
        node_info["identifier"] = db_id

        # Create Node
        node = Node(**node_info)

        # Check if we need to filter node; Skip by default if the node
        # belongs to the input nodes
        if not apply_filter or (
            not self.filter_input_node
            and basemodel_in_iterable(
                basemodel=node,
                iterable=self.input_nodes,
                any_item=False,
                exclude={"lookup"},
            )
        ):
            lookup = get_identifiers_url(db_name=db_ns, db_id=db_id) or ""
            node.lookup = lookup
            return node
        # Apply filters if there are any
        elif self.filter_options.no_node_filters() or self._pass_node(node=node):
            lookup = get_identifiers_url(db_name=db_ns, db_id=db_id) or ""
            node.lookup = lookup
            return node

        return None

    def _get_stmt_data(
        self,
        stmt_dict: Dict[str, Union[str, int, float, Dict[str, int]]],
        ev_limit: Optional[int] = None,
    ) -> Union[StmtData, None]:
        """If statement passes filter, return StmtData model"""
        # Only check _pass_stmt if:
        #   - filters are present or
        #   - the hash blacklist contain values
        if self.filter_options.no_stmt_filters() and not self._hash_blacklist:
            pass
        elif not self._pass_stmt(stmt_dict):
            return None

        try:
            if stmt_dict["stmt_type"] == "fplx":
                # stmt_hash == identifiers.org lookup for ontological edges
                url = stmt_dict["stmt_hash"]
            else:
                url = DB_URL_HASH.format(stmt_hash=stmt_dict["stmt_hash"])
                if ev_limit is not None:
                    url += f"&ev_limit={ev_limit}"
            return StmtData(db_url_hash=url, **stmt_dict)
        except ValidationError as err:
            # Log and skip statements whose data does not validate
            logger.warning(
                f"Validation of statement data failed for "
                f'"{stmt_dict.get("english", "(unknown statement)")}" with '
                f'hash {stmt_dict.get("stmt_hash", "(unknown hash)")}:'
            )
            logger.exception(err)
            return None

    def _get_edge_data(self, a: Union[StrNode, Node], b: Union[StrNode, Node]) -> Union[EdgeData, None]:
        """Assemble EdgeData for the edge a->b, or None if filtered out"""
        a_node = a if isinstance(a, Node) else self._get_node(a)
        b_node = b if isinstance(b, Node) else self._get_node(b)
        if a_node is None or b_node is None:
            return None
        edge = [a_node, b_node]
        # Graph edges are keyed by name tuples (signed or unsigned)
        str_edge = (
            (a_node.name, b_node.name)
            if a_node.sign is None
            else (a_node.signed_node_tuple(), b_node.signed_node_tuple())
        )
        ed: Dict[str, Any] = self._graph.edges[str_edge]

        # Create a StmtTypeSupport model; group statements by stmt_type
        stmt_dict: Dict[str, StmtTypeSupport] = {}
        for sd in ed["statements"]:
            stmt_data = self._get_stmt_data(stmt_dict=sd)
            if stmt_data:
                try:
                    stmt_dict[stmt_data.stmt_type].statements.append(stmt_data)
                except KeyError:
                    stmt_dict[stmt_data.stmt_type] = StmtTypeSupport(
                        stmt_type=stmt_data.stmt_type, statements=[stmt_data]
                    )

        # If all support was filtered out
        if not stmt_dict:
            return None

        # Set the source_count field for each StmtTypeSupport
        for sts in stmt_dict.values():
            sts.set_source_counts()

        edge_belief = ed["belief"]
        edge_weight = ed["weight"]
        edge_z_sc = ed["z_score"]
        edge_corr_weight = ed["corr_weight"]

        # Get sign and context weight if present
        extra_dict = {}
        if a_node.sign is not None and b_node.sign is not None:
            # Edge sign: 1 when the node signs differ, else 0
            sign = 1 if a_node.sign != b_node.sign else 0
            extra_dict["sign"] = sign
        if ed.get("context_weight"):
            extra_dict["context_weight"] = ed["context_weight"]

        url: str = DB_URL_EDGE.format(
            subj_id=a_node.identifier,
            subj_ns=a_node.namespace,
            obj_id=b_node.identifier,
            obj_ns=b_node.namespace,
            ev_limit=10,
        )
        edge_data = EdgeData(
            edge=edge,
            statements=stmt_dict,
            belief=edge_belief,
            weight=edge_weight,
            z_score=edge_z_sc,
            corr_weight=edge_corr_weight,
            db_url_edge=url,
            **extra_dict,
        )
        edge_data.set_source_counts()
        return edge_data

    def _get_results(self):
        # Main method for looping the path finding and results assembly
        raise NotImplementedError

    def _time_results(self):
        # This method executes and times the result assembly
        if self.start_time is None:
            self.start_time = datetime.utcnow()
        return self._get_results()

    def get_results(self):
        # Implement for each class
        raise NotImplementedError


class UIResultManager(ResultManager):
    """Parent class for all results that go to the UI"""

    filter_input_node = NotImplemented

    def __init__(
        self,
        path_generator: Union[Generator, Iterator, Iterable],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        hash_blacklist: Optional[Set[int]] = None,
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            input_nodes=[],  # Set in _set_source_target
            timeout=timeout,
        )
        # NOTE: input_nodes is set in _set_source_target in order to allow
        # calling _check_source_target *after* super.__init__() is called
        self._set_source_target(source=source, target=target)
        self._check_source_target()
        self._hash_blacklist: Set[int] = hash_blacklist or set()

    def _set_source_target(self, source: Union[Node, StrNode], target: Union[Node, StrNode]):
        """Resolve source/target to Node models and register as input nodes

        Raises
        ------
        ValueError
            If neither source nor target is provided
        """
        self.source = None
        self.target = None

        # Set source and/or target
        if not source and not target:
            raise ValueError("Must provide at least source or target for UI results")
        if source:
            # Input nodes are resolved without filtering so they are never
            # dropped by the node filters
            sn: Node = source if isinstance(source, Node) else self._get_node(source, apply_filter=False)
            self.source = sn
            self.input_nodes.append(sn)

        if target:
            tn: Node = target if isinstance(target, Node) else self._get_node(target, apply_filter=False)
            self.target = tn
            self.input_nodes.append(tn)

    def _check_source_or_target(self):
        """Check that exactly one of source and target is set to a Node

        Raises
        ------
        ValueError
            If source/target have the wrong type, or if both (or neither)
            of them are set
        """
        # Explicit checks instead of `assert` so that the validation is not
        # stripped away when running under `python -O`
        source_ok = isinstance(self.source, Node) or self.source is None
        target_ok = isinstance(self.target, Node) or self.target is None
        if not (source_ok and target_ok):
            raise ValueError(f"Source and target must be None or instance of Node for {self.alg_name}")

        # Only one of source and target allowed (xor)
        if (self.source is not None) == (self.target is not None):
            raise ValueError(f"Only one of source and target allowed for {self.alg_name}")

    def _check_source_and_target(self):
        """Check that both source and target are set to Node instances

        Raises
        ------
        ValueError
            If either source or target is missing or is not a Node
        """
        # Explicit checks instead of `assert` (see _check_source_or_target)
        if not (isinstance(self.source, Node) and isinstance(self.target, Node)):
            raise ValueError(f"Both source and target must be provided and be instance of Node for {self.alg_name}")

    def _check_source_target(self):
        """Check that source and target are set properly, i.e. not missing"""
        raise NotImplementedError

    def _pass_node(self, node: Node) -> bool:
        raise NotImplementedError

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        raise NotImplementedError

    def _get_results(self):
        raise NotImplementedError

    def get_results(self) -> BaseModel:
        raise NotImplementedError


class PathResultManager(UIResultManager):
    """Parent class for path result managers"""

    # The only thing needed in the children is defining _pass_node,
    # _pass_stmt, alg_name, _remove_used_filters and _check_source_target

    alg_name = NotImplemented
    filter_input_node = False

    def __init__(
        self,
        path_generator: Union[Generator, Iterable, Iterator],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        reverse: bool,
        timeout: float = DEFAULT_TIMEOUT,
        hash_blacklist: Optional[Set[int]] = None,
    ):
        """Initialize the path result manager

        Parameters
        ----------
        reverse :
            If True, the paths from the generator are reversed before
            assembly (e.g. for upstream open searches)
        """
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            timeout=timeout,
            hash_blacklist=hash_blacklist,
        )

        # Assembled paths, keyed by path length (number of nodes)
        self.paths: Dict[int, List[Path]] = {}
        self.reverse: bool = reverse

    def _check_source_target(self):
        raise NotImplementedError

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        raise NotImplementedError

    def _pass_node(self, node: Node) -> bool:
        raise NotImplementedError

    def _build_paths(self):
        """Loop the path generator and assemble Path models

        Stops when max_paths paths are assembled, the generator is
        exhausted, paths exceed the requested path length or the timeout
        is reached.
        """
        paths_built = 0
        prev_path: Optional[List[str]] = None
        culled_nodes: Set[str] = set()
        # Only set "context_weight" if non-strict context search is made
        if self.filter_options.context_weighted:
            assert self.filter_options.weighted == "context_weight"
            weight = self.filter_options.weighted
        # Since context weight is handled above, simply set other weight
        # options according to filer_options.weighted
        else:
            weight = self.filter_options.weighted

        while True:
            # Stop if the result assembly has taken too long
            if self.timeout and datetime.utcnow() - self.start_time > timedelta(seconds=self.timeout):
                logger.info(f"Timeout reached ({self.timeout} seconds), breaking results loop")
                self.timed_out = True
                break
            # Stop when the requested number of paths has been assembled
            if paths_built >= self.filter_options.max_paths:
                logger.info(f"Found all {self.filter_options.max_paths} shortest paths")
                break

            try:
                if self.filter_options.cull_best_node is not None and prev_path is not None:
                    # Get the set of nodes to ignore in the generator when
                    # culling of highly connected nodes is requested
                    send_values = _get_cull_values(
                        culled_nodes=culled_nodes,
                        cull_best_node=self.filter_options.cull_best_node,
                        prev_path=prev_path,
                        added_paths=paths_built,
                        graph=self._graph,
                        weight=weight,
                    )
                    # Send value affects current yield value, not next one:
                    # See https://stackoverflow.com/a/12638313/10478812
                    path = self.path_gen.send(send_values)
                else:
                    path = next(self.path_gen)

                # Reverse path if it is reversed, e.g. upstream open search
                if self.reverse:
                    path = path[::-1]

            except StopIteration:
                logger.info("Reached StopIteration in PathResultsManager, breaking.")
                break

            # For unweighted searches, enforce an exact path length: skip
            # shorter paths, and stop once paths get longer than requested
            if self.filter_options.path_length and not self.filter_options.overall_weighted:
                if len(path) < self.filter_options.path_length:
                    continue
                elif len(path) > self.filter_options.path_length:
                    logger.info(f"Found all paths of length " f"{self.filter_options.path_length}")
                    break
                else:
                    pass

            # Initialize variables for this iteration
            node_path: List[Node] = []
            edge_data_list = []
            filtered_out = False  # Flag for continuing loop
            edge_data = None  # To catch cases when no paths come out

            # Loop edges of path
            for s, o in zip(path[:-1], path[1:]):
                # Get edge data: if None, edge has been filtered out,
                # break and go to next path
                edge_data = self._get_edge_data(a=s, b=o)
                if edge_data is None or edge_data.is_empty():
                    filtered_out = True
                    break

                # Build PathResultData
                edge_data_list.append(edge_data)
                # Add subject node of edge
                node_path.append(edge_data.edge[0])

            # If inner loop was broken or never ran
            if filtered_out or edge_data is None:
                continue

            # Append final node
            node_path.append(edge_data.edge[1])
            assert len(node_path) == len(path)

            # Build data for current path
            path_data = Path(path=node_path, edge_data=edge_data_list)
            try:
                self.paths[len(path)].append(path_data)
            except KeyError:
                self.paths[len(path)] = [path_data]
            paths_built += 1

            # Caution: for reverse open searches, path is reversed here. This
            # doesn't affect _get_cull_values currently, but remember this
            # for future updates in this function
            prev_path = path

    def _get_results(self) -> PathResultData:
        """Returns the result for the associated algorithm"""
        try:
            if len(self.paths) == 0:
                self._build_paths()
            return PathResultData(source=self.source, target=self.target, paths=self.paths)
        except NetworkXNoPath as exc:
            # No path exists between source and target: return empty result
            logger.warning(str(exc))
            return PathResultData(paths={})

    def get_results(self) -> PathResultData:
        """Execute the result assembly with the loaded path generator

        Returns
        -------
        :
            Assembled paths with data as a BaseModel
        """
        return self._time_results()


class DijkstraResultManager(PathResultManager):
    """Handles results from open_dijkstra_search"""

    alg_name = open_dijkstra_search.__name__

    def __init__(
        self,
        path_generator: Union[Generator, Iterable, Iterator],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        reverse: bool,
        timeout: float = DEFAULT_TIMEOUT,
        hash_blacklist: Optional[Set[int]] = None,
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            reverse=reverse,
            timeout=timeout,
            hash_blacklist=hash_blacklist,
        )

    def _check_source_target(self):
        # Open search: exactly one of source/target must be provided
        self._check_source_or_target()

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Filters already done in algorithm
        #   node_blacklist
        #   terminal_ns <- Not part of FilterOptions currently
        #   cull best nodes <- Not applicable
        return FilterOptions(**filter_options.dict(exclude={"node_blacklist", "cull_best_node"}, exclude_defaults=True))

    def _pass_node(self, node: Node) -> bool:
        # open_dijkstra_search already checks:
        #   node_blacklist
        #   terminal_ns
        #
        # Still need to check:
        #   allowed_ns
        if node.namespace.lower() not in self.filter_options.allowed_ns:
            return False
        return True
class BreadthFirstSearchResultManager(PathResultManager):
    """Handles results from bfs_search"""

    alg_name = bfs_search.__name__

    def __init__(
        self,
        path_generator: Union[Generator, Iterable, Iterator],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        reverse: bool,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        # bfs_search filters on the hash blacklist internally, so no
        # blacklist is passed on to the result manager
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            reverse=reverse,
            timeout=timeout,
            hash_blacklist=None,
        )

    def _check_source_target(self):
        # Open search: exactly one of source/target must be provided
        self._check_source_or_target()

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Filters already done in algorithm
        # Node filters:
        #   ns filter
        #   node blacklist
        #   path len <-- not really though, BFS stops when paths starts to be
        #                longer than path_len, but also allows paths that are
        #                shorter
        #   terminal ns <-- not in post filtering anyway
        return FilterOptions(
            **filter_options.dict(
                exclude={
                    "allowed_ns",
                    "node_blacklist",
                },
                exclude_defaults=True,
            )
        )

    def _pass_node(self, node: Node) -> bool:
        # allowed_ns, node_blacklist and terminal_ns are all checked in
        # bfs_search
        return True
class ShortestSimplePathsResultManager(PathResultManager):
    """Handles results from shortest_simple_paths"""

    alg_name = shortest_simple_paths.__name__

    def __init__(
        self,
        path_generator: Union[Generator, Iterable, Iterator],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        timeout: float = DEFAULT_TIMEOUT,
        hash_blacklist: Optional[Set[int]] = None,
    ):
        # shortest_simple_paths is a source-to-target search, so the path
        # order is never reversed
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            reverse=False,
            timeout=timeout,
            hash_blacklist=hash_blacklist,
        )

    def _check_source_target(self):
        # Closed search: both source and target must be provided
        self._check_source_and_target()

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Filters already done in algorithm:
        #   node_blacklist
        return FilterOptions(**filter_options.dict(exclude={"node_blacklist"}, exclude_defaults=True))

    def _pass_node(self, node: Node) -> bool:
        # Check:
        #   - allowed_ns
        if node.namespace.lower() not in self.filter_options.allowed_ns:
            return False
        return True
class SharedInteractorsResultManager(UIResultManager):
    """Handles results from shared_interactors, both up and downstream

    downstream is True for shared targets and False for shared regulators
    """

    alg_name: str = shared_interactors.__name__
    filter_input_node = False

    def __init__(
        self,
        path_generator: Union[Iterable, Iterator, Generator],
        filter_options: FilterOptions,
        graph: DiGraph,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
        is_targets_query: bool,
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            hash_blacklist=None,
        )
        self._downstream: bool = is_targets_query

    def _check_source_target(self):
        # Both source and target must be provided
        self._check_source_and_target()

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Only add stmt data filters
        return FilterOptions(
            **filter_options.dict(
                include={
                    "stmt_filter",
                    "belief_cutoff",
                    "curated_db_only",
                },
                exclude_defaults=True,
            )
        )

    def _pass_node(self, node: Node) -> bool:
        # allowed_ns, node_blacklist are both check in algorithm
        return True

    def _get_results(self) -> SharedInteractorsResults:
        """Get results for shared_targets and shared_regulators"""
        source_edges: List[EdgeData] = []
        target_edges: List[EdgeData] = []
        # The generator yields pairs of (source edge, target edge)
        for (s1, s2), (t1, t2) in self.path_gen:
            if self.timeout and datetime.utcnow() - self.start_time > timedelta(seconds=self.timeout):
                logger.info(f"Timeout reached ({self.timeout} seconds), breaking results loop")
                self.timed_out = True
                break
            source_edge = self._get_edge_data(a=s1, b=s2)
            target_edge = self._get_edge_data(a=t1, b=t2)
            # Only keep the interactor if both of its edges survive filtering
            if source_edge and target_edge:
                source_edges.append(source_edge)
                target_edges.append(target_edge)
        return SharedInteractorsResults(
            source_data=source_edges,
            target_data=target_edges,
            downstream=self._downstream,
        )

    def get_results(self) -> SharedInteractorsResults:
        """Execute the result assembly with the loaded generator

        Returns
        -------
        :
            Results for shared_interactors as a BaseModel
        """
        return self._time_results()
class OntologyResultManager(UIResultManager):
    """Handles results from shared_parents"""

    alg_name: str = shared_parents.__name__
    filter_input_node = False

    def __init__(
        self,
        path_generator: Union[Iterable, Iterator, Generator],
        graph: DiGraph,
        filter_options: FilterOptions,
        source: Union[Node, StrNode],
        target: Union[Node, StrNode],
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            source=source,
            target=target,
            hash_blacklist=None,
        )
        self._parents: List[Node] = []

    def _check_source_target(self):
        # Both source and target must be provided
        self._check_source_and_target()

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # No filters are applied
        return FilterOptions()

    def _pass_node(self, node: Node) -> bool:
        # No filters are applied
        return True

    def _get_parents(self):
        # The generator yields (name, namespace, id, identifiers url) tuples
        for name, ns, _id, id_url in self.path_gen:
            if self.timeout and datetime.utcnow() - self.start_time > timedelta(seconds=self.timeout):
                logger.info(f"Timeout reached ({self.timeout} seconds), " f"breaking results loop")
                self.timed_out = True
                break
            node = Node(name=name, namespace=ns, identifier=_id, lookup=id_url)
            self._parents.append(node)

    def _get_results(self) -> OntologyResults:
        """Get results for shared_parents"""
        self._get_parents()
        return OntologyResults(source=self.source, target=self.target, parents=self._parents)

    def get_results(self) -> OntologyResults:
        """Execute the result assembly with the loaded generator

        Returns
        -------
        :
            Results for shared_parents as a BaseModel
        """
        return self._time_results()
def _get_cull_values(
    culled_nodes: Set[str],
    cull_best_node: int,
    prev_path: List[str],
    added_paths: int,
    graph: DiGraph,
    weight: Optional[str] = None,
) -> Tuple[Set[str], Set[str]]:
    """Update the set of culled nodes for the path generator

    Every cull_best_node paths, the intermediate node of the previous path
    with the highest (optionally weighted) degree is added to culled_nodes.

    Parameters
    ----------
    culled_nodes :
        The set of already culled nodes; updated in place and returned
    cull_best_node :
        Cull the highest-degree node every this many added paths
    prev_path :
        The most recently assembled path
    added_paths :
        The number of paths assembled so far
    graph :
        The graph searched, used to get node degrees
    weight :
        Optional edge attribute to weight the degree by

    Returns
    -------
    :
        A tuple of (culled nodes, culled edges); the edge set is always
        empty here
    """
    # Caution: prev path could be reversed if the search is e.g. upstream open
    # search. This function is invariant to order, so this is currently OK
    if (
        added_paths >= cull_best_node
        and added_paths % cull_best_node == 0
        and prev_path is not None
        and len(prev_path) >= 3
    ):
        # Only consider intermediate nodes, not the path endpoints
        degrees = graph.degree(prev_path[1:-1], weight)
        highest_degree_node = max(degrees, key=lambda x: x[1])[0]
        culled_nodes.add(highest_degree_node)
    return culled_nodes, set()
class SubgraphResultManager(ResultManager):
    """Handles results from get_subgraph_edges"""

    alg_name = get_subgraph_edges.__name__
    filter_input_node = False

    def __init__(
        self,
        path_generator: Iterator[Tuple[str, str]],
        graph: DiGraph,
        filter_options: FilterOptions,
        original_nodes: List[Node],
        nodes_in_graph: List[Node],
        not_in_graph: List[Node],
        ev_limit: int = 10,
        timeout: float = MAX_TIMEOUT,
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            filter_options=filter_options,
            input_nodes=original_nodes,
            timeout=timeout,
        )
        # Assembled edge data, keyed by (subject name, object name)
        self.edge_dict: Dict[Tuple[str, str], EdgeDataByHash] = {}
        self._available_nodes: Dict[str, Node] = {n.name: n for n in nodes_in_graph}
        self._not_in_graph: List[Node] = not_in_graph
        self._ev_limit = ev_limit

    def _pass_node(self, node: Node) -> bool:
        # No filters implemented yet
        return True

    def _pass_stmt(self, stmt_dict: Dict[str, Union[str, int, float, Dict[str, int]]]) -> bool:
        # Overwrite _pass_stmt() from parent to be able to filter out fplx
        # edges
        if stmt_dict["stmt_type"].lower() == "fplx":
            return False
        return True

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Add fplx as allowed type so that pass stmt gets called and overwrite
        # _pass_stmt to remove edges with it
        return FilterOptions(stmt_filter=["fplx"])

    def _get_edge_data_by_hash(self, a: Union[str, Node], b: Union[str, Node]) -> Union[EdgeDataByHash, None]:
        """Assemble EdgeDataByHash for the edge a->b, or None if filtered"""
        # Get node, return if unidentifiable
        a_node = a if isinstance(a, Node) else self._get_node(a)
        b_node = b if isinstance(b, Node) else self._get_node(b)
        if a_node is None or b_node is None:
            return None

        # Add lookup if not present
        if not a_node.lookup:
            a_node.lookup = get_identifiers_url(a_node.namespace, a_node.identifier)
        if not b_node.lookup:
            b_node.lookup = get_identifiers_url(b_node.namespace, b_node.identifier)

        # Get stmt data for edge
        edge = [a_node, b_node]
        ed: Dict[str, Any] = self._graph.edges[(a_node.name, b_node.name)]
        # Collect stmt_data by hash; keep the first occurrence of each hash
        stmt_dict: Dict[int, StmtData] = {}
        for sd in ed["statements"]:
            stmt_data = self._get_stmt_data(stmt_dict=sd, ev_limit=self._ev_limit)
            if stmt_data and stmt_data.stmt_hash not in stmt_dict:
                stmt_dict[stmt_data.stmt_hash] = stmt_data

        # If all support was filtered out
        if not stmt_dict:
            return None

        # Get edge aggregated belief, weight
        edge_belief = ed["belief"]
        edge_weight = ed["weight"]

        # FixMe: expose z-score and corr_weight here?

        edge_url = DB_URL_EDGE.format(
            subj_id=a_node.identifier,
            subj_ns=a_node.namespace,
            obj_id=b_node.identifier,
            obj_ns=b_node.namespace,
            ev_limit=self._ev_limit,
        )
        # Per-statement-type lookup urls for this edge
        edge_url_types = {}
        for st in stmt_dict.values():
            if st.stmt_type not in edge_url_types:
                edge_url_types[st.stmt_type] = edge_url + f"&type={st.stmt_type}"

        return EdgeDataByHash(
            edge=edge,
            stmts=stmt_dict,
            belief=edge_belief,
            weight=edge_weight,
            db_url_edge=edge_url,
            url_by_type=edge_url_types,
        )

    def _fill_data(self):
        """Build EdgeDataByHash for all edges, without duplicates"""
        logger.info(f"Generating output data for subgraph with " f"{len(self._available_nodes)} eligible nodes")
        # Loop edges
        for a, b in self.path_gen:
            # NOTE: the original code ran this timeout check twice in a row;
            # the duplicate was removed
            if self.timeout and datetime.utcnow() - self.start_time > timedelta(seconds=self.timeout):
                logger.info(f"Timeout reached ({self.timeout} seconds), " f"breaking results loop")
                self.timed_out = True
                break
            edge: Tuple[str, str] = (a, b)
            if edge not in self.edge_dict:
                # Use the Node model if available, otherwise the raw name
                half_edge = (
                    self._available_nodes[a] if a in self._available_nodes else a,
                    self._available_nodes[b] if b in self._available_nodes else b,
                )
                edge_data: EdgeDataByHash = self._get_edge_data_by_hash(*half_edge)
                if edge_data:
                    self.edge_dict[edge] = edge_data

    def _get_results(self) -> SubgraphResults:
        """Get results for get_subgraph_edges"""
        if not self.edge_dict and len(self._available_nodes) > 0:
            self._fill_data()
        edges: List[EdgeDataByHash] = list(self.edge_dict.values())

        return SubgraphResults(
            available_nodes=list(self._available_nodes.values()),
            edges=edges,
            input_nodes=self.input_nodes,
            not_in_graph=self._not_in_graph,
        )

    def get_results(self) -> SubgraphResults:
        """Execute the result assembly with the loaded generator

        Returns
        -------
        :
            Results for get_subgraph_edges as a BaseModel
        """
        return self._time_results()
class MultiInteractorsResultManager(ResultManager):
    """Handles results from `pathfinding.direct_multi_interactors`"""

    alg_name: str = direct_multi_interactors.__name__
    filter_input_node: bool = False

    def __init__(
        self,
        path_generator: Iterator,
        graph: DiGraph,
        input_nodes: List[StrNode],
        filter_options: FilterOptions,
        downstream: bool,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
    ):
        super().__init__(
            path_generator=path_generator,
            graph=graph,
            input_nodes=input_nodes,
            filter_options=filter_options,
            timeout=timeout,
        )
        self.downstream = downstream
        self.edge_data_list: Optional[List[EdgeData]] = []
        # For downstream searches the input nodes are the regulators,
        # otherwise they are the targets
        if self.downstream:
            self.regulators: List[Node] = [self._get_node(node_name=name, apply_filter=False) for name in input_nodes]
            self.targets: List[Node] = []
        else:
            self.regulators: List[Node] = []
            self.targets: List[Node] = [self._get_node(node_name=name, apply_filter=False) for name in input_nodes]

    def _pass_node(self, node: Node) -> bool:
        # Node blacklist and allowed ns are checked in direct_multi_interactors
        return True

    @staticmethod
    def _remove_used_filters(filter_options: FilterOptions) -> FilterOptions:
        # Add stmt data filters
        # NOTE(review): the sibling managers include "stmt_filter" here, not
        # "stmt_type" — confirm whether "stmt_type" is intentional
        return FilterOptions(
            **filter_options.dict(
                include={
                    "stmt_type",
                    "belief_cutoff",
                    "curated_db_only",
                },
                exclude_defaults=True,
            )
        )

    def _get_edge_iter(self) -> Iterable[Tuple[Node, Node]]:
        """Return all edges as (StrNode, StrNode)"""
        # If downstream, regulators == input nodes
        input_nodes = self.regulators if self.downstream else self.targets
        neighbors = [self._get_node(node_name=name, apply_filter=False) for name in self.path_gen]
        # Edges point from the input nodes to the neighbors for downstream
        # searches, and the other way around for upstream searches
        prod_args = (input_nodes, neighbors) if self.downstream else (neighbors, input_nodes)
        return ((s, o) for s, o in product(*prod_args))

    def _loop_edges(self):
        # Assemble EdgeData for each (subject, object) pair until done or
        # the timeout is reached
        for s, t in self._get_edge_iter():
            if self.timeout and datetime.utcnow() - self.start_time > timedelta(seconds=self.timeout):
                logger.info(f"Timeout reached ({self.timeout} seconds), " f"breaking results loop")
                self.timed_out = True
                break
            edge_data = self._get_edge_data(a=s, b=t)
            if edge_data:
                self.edge_data_list.append(edge_data)
        if self.edge_data_list:
            logger.info(f"Added data for {len(self.edge_data_list)} edges")
        else:
            logger.info(f"No common {'targets' if self.downstream else 'regulators'} was found for multi interactors")

    def _get_results(self) -> MultiInteractorsResults:
        if not self.edge_data_list:
            self._loop_edges()
        return MultiInteractorsResults(
            targets=self.targets,
            regulators=self.regulators,
            edge_data=self.edge_data_list,
        )

    def get_results(self) -> MultiInteractorsResults:
        """Execute the result assembly with the loaded generator

        Returns
        -------
        :
            Results for direct_multi_interactors as a BaseModel
        """
        return self._time_results()
# Map algorithm names to result classes
alg_manager_mapping = {
    shortest_simple_paths.__name__: ShortestSimplePathsResultManager,
    open_dijkstra_search.__name__: DijkstraResultManager,
    bfs_search.__name__: BreadthFirstSearchResultManager,
    "shared_targets": SharedInteractorsResultManager,
    "shared_regulators": SharedInteractorsResultManager,
    shared_parents.__name__: OntologyResultManager,
    get_subgraph_edges.__name__: SubgraphResultManager,
    direct_multi_interactors.__name__: MultiInteractorsResultManager,
}