Source code for indra_network_search.data_models.__init__

"""
This file contains data models for queries, results and arguments to algorithm
functions.
"""
# todo:
#  - Use constr(to_lower=True) in appropriate places to enforce lowercase:
#     + node_blacklist
#     + allowed_ns
#     + stmt_filter (allowed statement types)
#  - Use constr(min_length=N) to enforce that str fields are not empty
#  - Figure out how to use conlist and other con* enforcers for e.g.:
#     + Enforce hashes to be int and/or str
#     + Lowercase for string filters
#  - Figure out how to do "at least one of" filters. See:
#    https://github.com/samuelcolvin/pydantic/issues/506
#    Related: Check if it's possible to apply a setting that can be set on
#    creation to allow different checks, e.g. allow either of:
#          1) source XOR target
#          2) source AND target
#  - In FilterOptions, set overall_weighted based on the values of weighted
#    and context_weighted. See here for more info:
#    https://stackoverflow.com/q/54023782/10478812
import logging
from collections import Counter
from typing import Optional, List, Union, Callable, Tuple, Set, Dict, Iterable
from networkx import DiGraph

from pydantic import BaseModel, validator, Extra, constr, conint, confloat, \
    HttpUrl, conlist

from indra.explanation.pathfinding.util import EdgeFilter
from depmap_analysis.network_functions.net_functions import SIGNS_TO_INT_SIGN

from indra_network_search.rest_util import (
    get_query_hash,
    is_weighted,
    is_context_weighted,
    StrNode,
)

try:
    # Py 3.8+
    from typing import Literal
except ImportError:
    # Py 3.7-
    from typing_extensions import Literal

# Names exported when this module is star-imported: the query models, the
# algorithm option models, the result models, the module defaults and the
# model comparison helpers defined below.
__all__ = [
    "NetworkSearchQuery",
    "SubgraphRestQuery",
    "MultiInteractorsRestQuery",
    "ApiOptions",
    "ShortestSimplePathOptions",
    "BreadthFirstSearchOptions",
    "DijkstraOptions",
    "SharedInteractorsOptions",
    "OntologyOptions",
    "MultiInteractorsOptions",
    "Node",
    "StmtData",
    "EdgeData",
    "EdgeDataByHash",
    "Path",
    "PathResultData",
    "OntologyResults",
    "SharedInteractorsResults",
    "Results",
    "FilterOptions",
    "SubgraphOptions",
    "SubgraphResults",
    "MultiInteractorsResults",
    "DEFAULT_TIMEOUT",
    "WEIGHT_NAME_MAPPING",
    "basemodels_equal",
    "basemodel_in_iterable",
    "StmtTypeSupport",
]


# Module-level logger, named after this module
logger = logging.getLogger(__name__)


# Set defaults
# Default for NetworkSearchQuery.user_timeout and
# MultiInteractorsRestQuery.timeout (presumably seconds - TODO confirm)
DEFAULT_TIMEOUT = 30
# Translates the values accepted by NetworkSearchQuery.weighted into the values
# accepted by FilterOptions.weighted; looked up in
# NetworkSearchQuery.get_filter_options. "unweighted" maps to None.
WEIGHT_NAME_MAPPING = {
    "belief": "weight",
    "context": "context_weight",
    "z_score": "corr_weight",
    "unweighted": None,
}


# Models for API options and filtering options
class ApiOptions(BaseModel):
    """Options that determine API behaviour

    Every field is optional and carries a default, so the model can be
    created without arguments. The field names match the identically named
    attributes of NetworkSearchQuery.
    """

    sign: Optional[int] = None
    fplx_expand: Optional[bool] = False
    # Either a float or a bool; defaults to False
    user_timeout: Optional[Union[float, bool]] = False
    two_way: Optional[bool] = False
    shared_regulators: Optional[bool] = False
    format: Optional[str] = "json"
class FilterOptions(BaseModel):
    """Options for filtering out nodes or edges

    The defaults of all fields correspond to "no filtering"; the ``no_*``
    helper methods report whether the relevant fields still hold their
    defaults.
    """

    # Statement types and namespaces are lower-cased on validation
    stmt_filter: List[constr(to_lower=True)] = []
    allowed_ns: List[constr(to_lower=True)] = []
    node_blacklist: List[str] = []
    path_length: Optional[int] = None
    belief_cutoff: float = 0.0
    curated_db_only: bool = False
    max_paths: int = 50
    cull_best_node: Optional[int] = None
    # Name of the weight to use, None when no weight is selected
    weighted: Optional[Literal["weight", "context_weight", "corr_weight"]] = None
    context_weighted: bool = False
    overall_weighted: bool = False

    def no_filters(self) -> bool:
        """Return True if all filter options are set to defaults

        Checks the statement filters, the node filters and path_length.
        """
        return (
            self.no_stmt_filters()
            and self.no_node_filters()
            and self.path_length is None
        )

    def no_stmt_filters(self) -> bool:
        """Return True if the stmt filter options allow all statements"""
        return (
            self.belief_cutoff == 0.0
            and not self.stmt_filter
            and self.curated_db_only is False
        )

    def no_node_filters(self) -> bool:
        """Return True if the node filter options allow all nodes"""
        return not self.node_blacklist and not self.allowed_ns
class NetworkSearchQuery(BaseModel):
    """The query model for network searches

    Instances are immutable and reject fields that are not declared here
    (see ``Config``).
    """

    source: constr(strip_whitespace=True) = ""
    target: constr(strip_whitespace=True) = ""
    stmt_filter: List[constr(to_lower=True, strip_whitespace=True)] = []
    filter_curated: bool = True
    allowed_ns: List[constr(to_lower=True, strip_whitespace=True)] = []
    node_blacklist: List[str] = []
    path_length: Optional[int] = None
    depth_limit: int = 2
    sign: Optional[conint(ge=0, le=1)] = None
    weighted: Literal["belief", "context", "z_score", "unweighted"] = "unweighted"
    belief_cutoff: Union[float, bool] = 0.0
    curated_db_only: bool = False
    fplx_expand: bool = False
    k_shortest: int = 50
    max_per_node: int = 5
    cull_best_node: Optional[int] = None
    mesh_ids: List[str] = []
    strict_mesh_id_filtering: bool = False
    const_c: int = 1
    const_tk: int = 10
    user_timeout: Union[float, bool] = DEFAULT_TIMEOUT
    two_way: bool = False
    shared_regulators: bool = False
    terminal_ns: List[str] = []
    format: str = "json"  # This attribute is probably obsolete now

    @validator("path_length")
    def is_positive_int(cls, pl: int):
        """Validate path_length >= 1 if given"""
        if isinstance(pl, int) and pl < 1:
            raise ValueError("path_length must be integer > 0")
        return pl

    @validator("max_per_node")
    def is_pos_int(cls, mpn: Union[int, bool]):
        """Validate max_per_node >= 1 if given"""
        if isinstance(mpn, int) and mpn < 1:
            raise ValueError("max_per_node must be integer > 0")
        return mpn

    @validator("cull_best_node")
    def is_int_gt2(cls, cbn: Optional[int]):
        """Validate cull_best_node >= 2"""
        if isinstance(cbn, int) and cbn < 2:
            raise ValueError("cull_best_node must be integer > 1 if provided")
        return cbn

    class Config:
        allow_mutation = False  # Error for any attempt to change attributes
        extra = Extra.forbid  # Error if non-specified attributes are given

    def is_overall_weighted(self) -> bool:
        """Return True if this query is weighted

        This method is used to determine if a weighted search needs to be
        done using either of shortest_simple_paths and open_dijkstra_search.

        The exception to self.weighted not being None but still be
        unweighted is strict mesh id search.
        """
        return is_weighted(
            weighted=self.weighted in ("belief", "z_score"),
            mesh_ids=self.mesh_ids,
            strict_mesh_filtering=self.strict_mesh_id_filtering,
        )

    def is_context_weighted(self) -> bool:
        """Return True if this query is context weighted"""
        # The bare name resolves to the module-level helper imported from
        # rest_util, not to this method (class attributes are not in scope
        # inside method bodies).
        return is_context_weighted(
            mesh_id_list=self.mesh_ids, strict_filtering=self.strict_mesh_id_filtering
        )

    def get_hash(self):
        """Get the corresponding query hash of the query

        The ``format`` field is ignored when computing the hash.
        """
        return get_query_hash(self.dict(), ignore_keys=["format"])

    def get_int_sign(self) -> Optional[int]:
        """Return the integer representation of the sign

        Returns
        -------
        :
            None if no sign is set. Otherwise 0 or 1 if the sign can be
            converted to one of those integers, else whatever
            SIGNS_TO_INT_SIGN holds for the sign (None if it has no entry).
        """
        if self.sign is None or self.sign == "":
            return None

        try:
            sign = int(self.sign)
            # Explicit check rather than ``assert``: assertions are removed
            # when Python runs with -O, which would let other integers pass.
            if sign not in (0, 1):
                raise ValueError(f"sign must be 0 or 1, got {sign}")
        except (TypeError, ValueError) as exc:
            logger.info(
                "Could not convert %s of type %s to int (%s), trying "
                "SIGNS mapping",
                self.sign,
                type(self.sign),
                exc,
            )
            sign = SIGNS_TO_INT_SIGN.get(self.sign)
        return sign

    def get_filter_options(self) -> FilterOptions:
        """Returns the filter options

        ``k_shortest`` is passed on as ``max_paths`` and ``weighted`` is
        translated through WEIGHT_NAME_MAPPING.
        """
        return FilterOptions(
            stmt_filter=self.stmt_filter,
            allowed_ns=self.allowed_ns,
            node_blacklist=self.node_blacklist,
            path_length=self.path_length,
            belief_cutoff=self.belief_cutoff,
            curated_db_only=self.curated_db_only,
            max_paths=self.k_shortest,
            cull_best_node=self.cull_best_node,
            overall_weighted=self.is_overall_weighted(),
            weighted=WEIGHT_NAME_MAPPING.get(self.weighted),
            context_weighted=self.is_context_weighted(),
        )
# Models for the run options
# Todo:
#  1. Instead of manually setting defaults here, use introspection of the
#     function and look up the function's defaults:
#     >>> def func(par: int = 0):
#     ...     return par
#     >>> import inspect
#     >>> func_pars = inspect.signature(func).parameters
#     >>> arg = func_pars['par']
#     >>> arg.default
#  2. For "not-None" defaults: set value to default if None is provided:
#     https://stackoverflow.com/q/63616798/10478812
#     Good for e.g. max_paths
class ShortestSimplePathOptions(BaseModel):
    """Arguments for indra.explanation.pathfinding.shortest_simple_paths

    ``source`` and ``target`` are required; all other fields have defaults.
    """

    # Either a plain string or a (str, int) tuple - presumably a signed node
    # tuple such as the one produced by Node.signed_node_tuple (TODO confirm)
    source: Union[str, Tuple[str, int]]
    target: Union[str, Tuple[str, int]]
    weight: Optional[str] = None
    ignore_nodes: Optional[Set[str]] = None
    ignore_edges: Optional[Set[Tuple[str, str]]] = None
    hashes: Optional[List[int]] = None
    ref_counts_function: Optional[Callable] = None
    strict_mesh_id_filtering: Optional[bool] = False
    const_c: Optional[int] = 1
    const_tk: Optional[int] = 10
class BreadthFirstSearchOptions(BaseModel):
    """Arguments for indra.explanation.pathfinding.bfs_search

    Only ``source_node`` is required; all other fields have defaults.
    """

    # Either a plain string or a (str, int) tuple
    source_node: Union[str, Tuple[str, int]]
    reverse: Optional[bool] = False
    depth_limit: Optional[int] = 2
    path_limit: Optional[int] = None
    max_per_node: Optional[int] = 5
    node_filter: Optional[List[str]] = None
    node_blacklist: Optional[Set[str]] = None
    terminal_ns: Optional[List[str]] = None
    sign: Optional[int] = None
    # 2 ** 29 == 536870912; presumably a byte count - TODO confirm
    max_memory: Optional[int] = 2 ** 29
    hashes: Optional[List[int]] = None
    # Callable taking a graph and two nodes and returning a bool
    allow_edge: Optional[Callable[[DiGraph, StrNode, StrNode], bool]] = None
    edge_filter: Optional[EdgeFilter] = None
    strict_mesh_id_filtering: Optional[bool] = False
class DijkstraOptions(BaseModel):
    """Arguments for open_dijkstra_search

    Only ``start`` is required; all other fields have defaults.
    """

    # Either a plain string or a (str, int) tuple
    start: Union[str, Tuple[str, int]]
    reverse: Optional[bool] = False
    path_limit: Optional[int] = None
    # node_filter: Optional[List[str]] = None  # Currently not implemented
    hashes: Optional[List[int]] = None
    ignore_nodes: Optional[List[str]] = None
    ignore_edges: Optional[List[Tuple[str, str]]] = None
    terminal_ns: Optional[List[str]] = None
    weight: Optional[str] = None
    ref_counts_function: Optional[Callable] = None
    const_c: Optional[int] = 1
    const_tk: Optional[int] = 10
class SharedInteractorsOptions(BaseModel):
    """Arguments for indra_network_search.pathfinding.shared_interactors

    ``source`` and ``target`` are required; all other fields have defaults.
    """

    source: StrNode
    target: StrNode
    allowed_ns: Optional[List[str]] = None
    stmt_types: Optional[List[str]] = None
    source_filter: Optional[List[str]] = None
    max_results: Optional[int] = 50
    regulators: Optional[bool] = False
    sign: Optional[int] = None
class OntologyOptions(BaseModel):
    """Arguments for indra_network_search.pathfinding.shared_parents

    The namespace and identifier of both source and target are required.
    """

    source_ns: str
    source_id: str
    target_ns: str
    target_id: str
    max_paths: int = 50
    immediate_only: Optional[bool] = False
    is_a_part_of: Optional[Set[str]] = None
class MultiInteractorsOptions(BaseModel):
    """Multi interactors options

    ``nodes`` and ``downstream`` are required; all other fields have
    defaults.
    """

    nodes: List[str]
    downstream: bool
    allowed_ns: Optional[List[str]] = None
    stmt_types: Optional[List[str]] = None
    source_filter: Optional[List[str]] = None
    max_results: int = 50
    hash_blacklist: Optional[Set[int]] = None
    node_blacklist: Optional[List[str]] = None
    belief_cutoff: float = 0.0
    curated_db_only: bool = False
# Models and sub-models for the Results
class Node(BaseModel):
    """Data for a node

    ``namespace`` and ``identifier`` must be non-empty strings. ``name``,
    ``lookup`` and ``sign`` are optional; when given, the strings must be
    non-empty and the sign must be 0 or 1.
    """

    name: Optional[constr(min_length=1)]
    namespace: constr(min_length=1)
    identifier: constr(min_length=1)
    lookup: Optional[constr(min_length=1)]
    sign: Optional[conint(ge=0, le=1)]

    def get_unsigned_node(self):
        """Get unsigned version of this node instance

        Returns
        -------
        :
            A new instance of the same class built from this node's data
            with ``sign`` and all fields that hold their default left out
        """
        unsigned_data = self.dict(exclude={"sign"}, exclude_defaults=True)
        return self.__class__(**unsigned_data)

    def signed_node_tuple(self) -> Tuple[str, int]:
        """Get a signed node tuple of node name and node sign

        Returns
        -------
        :
            A name, sign tuple

        Raises
        ------
        TypeError
            If sign is not defined, a TypeError
        """
        if self.sign is not None:
            return self.name, self.sign

        raise TypeError("Node is unsigned, unable to produce a signed node tuple")
class StmtData(BaseModel):
    """Data for one statement supporting an edge"""

    stmt_type: str
    # At least one evidence is required
    evidence_count: conint(ge=1)
    # Either an integer or a URL
    stmt_hash: Union[int, HttpUrl]
    source_counts: Dict[str, int]
    # Constrained to the closed interval [0, 1]
    belief: confloat(ge=0.0, le=1.0)
    curated: bool
    english: str
    weight: Optional[float] = None
    residue: Optional[str] = ""
    position: Optional[str] = ""
    initial_sign: Optional[conint(ge=0, le=1)] = None
    # Linkout to hash-level
    db_url_hash: str
class StmtTypeSupport(BaseModel):
    """Data per statement type"""

    stmt_type: str
    source_counts: Dict[str, int] = {}
    statements: List[StmtData]

    def set_source_counts(self):
        """Updates the source count field from the set statement data

        The source counts of all statements are added up per source and the
        result (a Counter) replaces ``self.source_counts``.
        """
        totals = Counter()
        for stmt_data in self.statements:
            # Counter addition only keeps entries with a positive count
            totals = totals + Counter(**stmt_data.source_counts)
        self.source_counts = totals
class EdgeData(BaseModel):
    """Data for one single edge"""

    # Edge supported by statements
    edge: List[Node]
    # key by stmt_type
    statements: Dict[str, StmtTypeSupport]
    # Aggregated belief
    belief: confloat(ge=0, le=1)
    # Weight corresponding to aggregated belief weight
    weight: confloat(ge=0)
    # Set for context
    context_weight: Union[str, confloat(gt=0), Literal["N/A"]] = "N/A"
    # z-score
    z_score: Optional[float] = None
    # Weight from z-score
    corr_weight: Optional[confloat(gt=0.0)] = None
    # Used for signed paths
    sign: Optional[conint(ge=0, le=1)]
    # Linkout to subj-obj level
    db_url_edge: str
    source_counts: Dict[str, int] = {}

    def is_empty(self) -> bool:
        """Return True if len(statements) == 0"""
        return not self.statements

    def set_source_counts(self):
        """Updates the source count from the contained data in self.statements

        The source counts of every StmtTypeSupport value are added up per
        source and the result (a Counter) replaces ``self.source_counts``.
        """
        totals = Counter()
        for type_support in self.statements.values():
            # Counter addition only keeps entries with a positive count
            totals = totals + Counter(**type_support.source_counts)
        self.source_counts = totals
class EdgeDataByHash(BaseModel):
    """Data for one single edge, with data keyed by hash"""

    edge: List[Node]
    # Hash remain as int for JSON
    stmts: Dict[int, StmtData]
    belief: float
    weight: float
    # Linkout to subj-obj level
    db_url_edge: str
    # Linkout per statement type
    url_by_type: Dict[str, str]
    # sign: Optional[int]  # Used for signed paths
    # context_weight: Union[str, float] = 'N/A'  # Set for context search
class Path(BaseModel):
    """Results for a single path"""

    # The entries are assumed to be co-ordered
    # path = [a, b, c]
    # edge_data = [EdgeData(a, b), EdgeData(b, c)]

    # Contains the path
    path: List[Node]
    # Contains supporting data, same order as path
    edge_data: List[EdgeData]

    def is_empty(self) -> bool:
        """Return True if len(path) == 0 or len(edge_data) == 0"""
        return not (self.path and self.edge_data)
class PathResultData(BaseModel):
    """Results for any of the path algorithms"""

    # Results for bfs_search, shortest_simple_paths and open_dijkstra_search
    # It is assumed that at least one of source or target will be set
    source: Optional[Node] = None
    target: Optional[Node] = None
    # keyed by node count
    paths: Dict[int, List[Path]]

    def is_empty(self) -> bool:
        """Return True if paths list is empty"""
        return not self.paths
class OntologyResults(BaseModel):
    """Results for shared_parents"""

    source: Node
    target: Node
    parents: List[Node]

    def is_empty(self) -> bool:
        """Return True if parents list is empty"""
        return not self.parents
class SharedInteractorsResults(BaseModel):
    """Results for shared targets and shared regulators"""

    # s->x; t->x
    source_data: List[EdgeData]
    target_data: List[EdgeData]
    downstream: bool

    def is_empty(self) -> bool:
        """Return True if both source and target data is empty"""
        return not (self.source_data or self.target_data)
class SubgraphResults(BaseModel):
    """Results for get_subgraph_edges

    All four fields are required.
    """

    input_nodes: List[Node]
    not_in_graph: List[Node]
    available_nodes: List[Node]
    edges: List[EdgeDataByHash]
class MultiInteractorsResults(BaseModel):
    """Results post direct_multi_interactors

    ``targets`` and ``regulators`` are required; ``edge_data`` defaults to
    an empty list.
    """

    targets: List[Node]
    regulators: List[Node]
    edge_data: List[EdgeData] = []
class Results(BaseModel):
    """The model wrapping all results from the NetworkSearchQuery

    ``query_hash``, ``time_limit`` and ``timed_out`` are required; each of
    the result fields defaults to None.
    """

    query_hash: str
    time_limit: float
    timed_out: bool
    # Cast as string for JavaScript
    hashes: List[str] = []
    path_results: Optional[PathResultData] = None
    reverse_path_results: Optional[PathResultData] = None
    ontology_results: Optional[OntologyResults] = None
    shared_target_results: Optional[SharedInteractorsResults] = None
    shared_regulators_results: Optional[SharedInteractorsResults] = None
class MultiInteractorsRestQuery(BaseModel):
    """Multi interactors rest query

    ``nodes`` and ``downstream`` are required; all other fields have
    defaults.
    """

    nodes: List[str]
    downstream: bool
    # The three filters below take non-empty strings that are stripped of
    # surrounding whitespace and lower-cased on validation
    allowed_ns: Optional[
        List[constr(strip_whitespace=True, to_lower=True, min_length=1)]
    ] = None
    stmt_types: Optional[
        List[constr(strip_whitespace=True, to_lower=True, min_length=1)]
    ] = None
    source_filter: Optional[
        List[constr(strip_whitespace=True, to_lower=True, min_length=1)]
    ] = None
    max_results: int = 50
    node_blacklist: Optional[List[str]] = None
    belief_cutoff: float = 0.0
    curated_db_only: bool = False
    # Must lie in the closed interval [5.0, 120.0]
    timeout: confloat(ge=5.0, le=120.0) = DEFAULT_TIMEOUT
class SubgraphRestQuery(BaseModel):
    """Subgraph query"""

    # Between 1 and 4000 nodes must be provided
    nodes: conlist(item_type=Node, min_items=1, max_items=4000)
class SubgraphOptions(BaseModel):
    """Argument for indra_network_search.pathfinding.get_subgraph_edges"""

    # Required list of nodes
    nodes: List[Node]
def basemodels_equal(
    basemodel: BaseModel,
    other_basemodel: BaseModel,
    any_item: bool,
    exclude: Optional[Set[str]] = None,
) -> bool:
    """Wrapper to test two basemodels for equality, can exclude irrelevant keys

    Fields are compared by name. Values are compared with ``_equals``.

    Parameters
    ----------
    basemodel :
        BaseModel to test against other_basemodel
    other_basemodel :
        BaseModel to test against basemodel
    any_item :
        If True, use any() when testing collections for equality, otherwise
        use all(), i.e. the collections must match exactly
    exclude :
        A set of field names to exclude from the basemodels

    Returns
    -------
    :
        True if the two models are equal
    """
    b1d = basemodel.dict(exclude=exclude)
    b2d = other_basemodel.dict(exclude=exclude)

    if any_item:
        # One field that exists in both models with equal values is enough
        return any(
            key in b2d and _equals(value, b2d[key], any_item)
            for key, value in b1d.items()
        )

    # Exact match: the models must expose the same fields. Pairing keys
    # positionally with zip() would silently compare unrelated fields or
    # drop the surplus fields of the larger model.
    if b1d.keys() != b2d.keys():
        return False
    return all(_equals(value, b2d[key], any_item) for key, value in b1d.items())
def _equals( d1: Union[str, int, float, List, Set, Tuple, Dict], d2: Union[str, int, float, List, Set, Tuple, Dict], any_item: bool, ) -> bool: qual_func = any if any_item else all if d1 is None: return d2 is None elif isinstance(d1, (str, int, float)): return d1 == d2 elif isinstance(d1, (list, tuple)): return qual_func(_equals(e1, e2, any_item) for e1, e2 in zip(d1, d2)) elif isinstance(d1, set): return d1 == d2 elif isinstance(d1, dict): return qual_func(_equals(d1[k1], d2[k2], False) for k1, k2 in zip(d1, d2)) else: raise TypeError(f"Unable to do comparison of type {type(d1)}")
def basemodel_in_iterable(
    basemodel: BaseModel,
    iterable: Iterable,
    any_item: bool,
    exclude: Optional[Set[str]] = None,
) -> bool:
    """Test if a basemodel object is part of a collection

    Parameters
    ----------
    basemodel :
        A BaseModel to test membership in iterable for
    iterable :
        An iterable that contains objects to test for equality with
        basemodel
    any_item :
        If True, use any() when testing collections for equality, otherwise
        use all(), i.e. the collections must match exactly
    exclude :
        A set of field names to exclude from the basemodels

    Returns
    -------
    :
        True if basemodel is found in the collection
    """
    # A generator lets any() stop at the first match instead of comparing
    # against every object in the iterable first
    return any(
        basemodels_equal(
            basemodel=basemodel,
            other_basemodel=ob,
            any_item=any_item,
            exclude=exclude,
        )
        for ob in iterable
    )