"""
The IndraNetworkSearch REST API
"""
import logging
from datetime import date
from os import environ
from typing import List, Optional
from depmap_analysis.network_functions.net_functions import MIN_WEIGHT, bio_ontology
from depmap_analysis.util.io_functions import file_opener
from fastapi import BackgroundTasks, FastAPI
from fastapi import Query as RestQuery
from indra.databases import get_identifiers_url
from pydantic import ValidationError
from tqdm import tqdm
from indra_network_search.autocomplete import NodesTrie, Prefixes
from indra_network_search.data_models import (
MultiInteractorsRestQuery,
MultiInteractorsResults,
NetworkSearchQuery,
Node,
Results,
SubgraphRestQuery,
SubgraphResults,
)
from indra_network_search.data_models.rest_models import Health, ServerStatus
from indra_network_search.rest_util import (
check_existence_and_date_s3,
dump_query_json_to_s3,
dump_result_json_to_s3,
load_indra_graph,
)
from indra_network_search.search_api import IndraNetworkSearchAPI
logger = logging.getLogger(__name__)
NAME = "INDRA Network Search"
VERSION = "1.0.0"
app = FastAPI(
title=NAME,
root_path="/api",
version=VERSION,
)
DEBUG = environ.get("API_DEBUG") == "1"
USE_CACHE = environ.get("USE_CACHE") == "1"
HEALTH = Health(status="booting")
STATUS = ServerStatus(status="booting", graph_date="2022-01-11")
# Module-level globals, populated in startup_event() below
network_search_api: IndraNetworkSearchAPI
nsid_trie: NodesTrie
nodes_trie: NodesTrie
[docs]@app.get("/xrefs", response_model=List[List[str]])
def get_xrefs(ns: str, id: str) -> List[List[str]]:
"""Get all cross-refs given a namespace and ID
Parameters
----------
ns :
The namespace of the entity to find cross-refs for
id :
The identifier of the entity to find cross-regs for
Returns
-------
:
A list of tuples containing namespace, identifier, lookup url to
identifiers.org
"""
# Todo: offload util features and capabilities, such as this one, to a new
# UtilApi class
xrefs = bio_ontology.get_mappings(ns=ns, id=id)
xrefs_w_lookup = [[n, i, get_identifiers_url(n, i)] for n, i in xrefs]
return xrefs_w_lookup
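
# Illustrative usage (a sketch: host and port, and the ns/id values, are
# placeholders; the "/api" root path only applies behind a proxy that
# strips it):
#
#     curl "http://localhost:8000/xrefs?ns=hgnc&id=11018"
#     # -> [["<namespace>", "<identifier>", "<identifiers.org URL>"], ...]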
[docs]@app.get("/node-name-in-graph", response_model=Optional[Node])
def node_name_in_graph(node_name: str = RestQuery(..., min_length=1, alias="node-name")) -> Optional[Node]:
"""Check if node by provided name (case sensitive) exists in graph
Parameters
----------
node_name :
The name of the node to check
Returns
-------
:
When a match is found, the full information of the node is returned
"""
node = network_search_api.get_node(node_name)
if node:
return node
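
# Illustrative usage (a sketch; "BRCA1" is a placeholder node name and
# matching is case sensitive):
#
#     curl "http://localhost:8000/node-name-in-graph?node-name=BRCA1"
#     # -> the Node as JSON if it exists in the graph, otherwise null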
[docs]@app.get("/node-id-in-graph", response_model=Optional[Node])
def node_id_in_graph(
db_name: str = RestQuery(..., min_length=2, alias="db-name"),
db_id: str = RestQuery(..., min_length=1, alias="db-id"),
) -> Optional[Node]:
"""Check if a node by provided db name and db id exists
Parameters
----------
db_name :
The database name, e.g. hgnc, chebi or up
db_id :
The identifier for the entity in the given database, e.g. 11018
Returns
-------
:
When a match is found, the full information of the node is returned
"""
node = network_search_api.get_node_by_ns_id(db_ns=db_name, db_id=db_id)
if node:
return node
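
# Illustrative usage (a sketch; note the hyphenated parameter aliases, and
# that the values below are placeholders):
#
#     curl "http://localhost:8000/node-id-in-graph?db-name=hgnc&db-id=11018"
#     # -> the Node as JSON if it exists in the graph, otherwise null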
[docs]@app.get("/autocomplete", response_model=Prefixes)
def get_prefix_autocomplete(
prefix: str = RestQuery(..., min_length=1),
max_res: int = RestQuery(100, alias="max-results"),
) -> Prefixes:
"""Get the case-insensitive node names with (ns, id) starting in prefix
Parameters
----------
prefix :
The prefix of a node name to search for. Note: for prefixes of
1 and 2 characters, only exact matches are returned. For 3+
characters, prefix matching is done. If the prefix contains ':',
an namespace:id search is done.
max_res :
The top ranked (by node degree) results will be returned, cut off at
this many results.
Returns
-------
:
A list of tuples of (node name, namespace, identifier)
"""
# Catch very short entity names
if 1 <= len(prefix) <= 2 and ":" not in prefix:
logger.info("Got short node name lookup")
# Loop all combinations of upper and lowercase
if len(prefix) == 1:
nodes = []
upper_match = network_search_api.get_node(prefix.upper())
lower_match = network_search_api.get_node(prefix.lower())
if upper_match:
nodes.append([upper_match.name, upper_match.namespace, upper_match.identifier])
if lower_match:
nodes.append([lower_match.name, lower_match.namespace, lower_match.identifier])
        else:
            # len(prefix) == 2: try all four case combinations
            nodes = []
            n1 = prefix.upper()
            n2 = prefix[0].lower() + prefix[1].upper()
            n3 = prefix[0].upper() + prefix[1].lower()
            n4 = prefix.lower()
for p in [n1, n2, n3, n4]:
m = network_search_api.get_node(p)
if m:
nodes.append([m.name, m.namespace, m.identifier])
# Look up ns:id searches
elif ":" in prefix:
logger.info("Got ns:id prefix check")
nodes = nsid_trie.case_items(prefix=prefix, top_n=max_res)
else:
logger.info("Got name prefix check")
nodes = nodes_trie.case_items(prefix=prefix, top_n=max_res)
logger.info(f"Prefix query resolved with {len(nodes)} suggestions")
return nodes
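
# Illustrative usage (a sketch with placeholder values; the second call
# triggers the ns:id lookup since the prefix contains ':'):
#
#     curl "http://localhost:8000/autocomplete?prefix=TNF&max-results=10"
#     curl "http://localhost:8000/autocomplete?prefix=hgnc:11018"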
[docs]@app.get("/health", response_model=Health)
async def health():
"""Returns health status
Returns
-------
Health
"""
logger.info("Got health check")
return HEALTH
[docs]@app.get("/status", response_model=ServerStatus)
async def server_status():
"""Returns the status of the server and some info about the loaded graphs
Returns
-------
:
"""
logger.info("Got status check")
return STATUS
[docs]@app.post("/query", response_model=Results)
def query(search_query: NetworkSearchQuery, background_tasks: BackgroundTasks):
"""Interface with IndraNetworkSearchAPI.handle_query
Parameters
----------
search_query : NetworkSearchQuery
Query to the NetworkSearchQuery
Returns
-------
Results
"""
query_hash = search_query.get_hash()
logger.info(f"Got NetworkSearchQuery #{query_hash}: {search_query.dict()}")
# Check if results are on S3
keys_dict = check_existence_and_date_s3(query_hash=query_hash)
if keys_dict.get("result_json_key"):
logger.info("Found results cached on S3")
results_json = file_opener(keys_dict["result_json_key"])
try:
results = Results(**results_json)
except ValidationError as verr:
logger.error(verr)
logger.info("Result could not be validated, re-running search")
results = network_search_api.handle_query(rest_query=search_query)
logger.info("Uploading results to S3")
background_tasks.add_task(dump_result_json_to_s3, query_hash, results.dict())
background_tasks.add_task(dump_query_json_to_s3, query_hash, search_query.dict())
else:
logger.info("Performing new search")
results = network_search_api.handle_query(rest_query=search_query)
logger.info("Uploading results to S3")
background_tasks.add_task(dump_result_json_to_s3, query_hash, results.dict())
background_tasks.add_task(dump_query_json_to_s3, query_hash, search_query.dict())
return results
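
# Illustrative usage (a minimal sketch, assuming NetworkSearchQuery accepts
# `source` and `target` node names; check the data model for the full set
# of options):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/query",
#         json={"source": "BRCA1", "target": "BRCA2"},
#     )
#     results = resp.json()  # conforms to the Results model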
@app.post("/multi_interactors", response_model=MultiInteractorsResults)
def multi_interactors(search_query: MultiInteractorsRestQuery):
logger.info(f"Got multi interactors query with {len(search_query.nodes)} nodes")
results = network_search_api.handle_multi_interactors_query(multi_interactors_rest_query=search_query)
logger.info("Multi interactors query resolved")
return results
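
# Illustrative usage (a sketch; only the `nodes` field is confirmed by the
# handler above, any other MultiInteractorsRestQuery fields should be
# checked against the data model):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/multi_interactors",
#         json={"nodes": ["BRCA1", "BRCA2"]},
#     )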
[docs]@app.post("/subgraph", response_model=SubgraphResults)
def sub_graph(search_query: SubgraphRestQuery):
"""Interface with IndraNetworkSearchAPI.handle_subgraph_query
Parameters
----------
search_query: SubgraphRestQuery
Query to for IndraNetworkSearchAPI.handle_subgraph_query
Returns
-------
SubgraphResults
"""
logger.info(f"Got subgraph query with {len(search_query.nodes)} nodes")
subgraph_results = network_search_api.handle_subgraph_query(subgraph_rest_query=search_query)
logger.info("Subgraph query resolved")
return subgraph_results
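
# Illustrative usage (a sketch, assuming SubgraphRestQuery carries a list
# of nodes with the same fields as the Node model used above):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/subgraph",
#         json={"nodes": [{"name": "BRCA1", "namespace": "HGNC",
#                          "identifier": "1100"}]},
#     )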
@app.on_event("startup")
def startup_event():
global network_search_api, nsid_trie, nodes_trie
# Todo: figure out how to do all the loading async so the server is
# available to respond to health checks while it's loading
# See:
# - https://fastapi.tiangolo.com/advanced/events/#startup-event
# - https://www.starlette.io/events/
if DEBUG:
from indra_network_search.tests.util import (
unsigned_graph,
signed_node_graph,
)
dir_graph = unsigned_graph
sign_node_graph = signed_node_graph
else:
        # Todo: the file IO has to be awaited to make this function async
dir_graph, _, _, sign_node_graph = load_indra_graph(
unsigned_graph=True,
unsigned_multi_graph=False,
sign_node_graph=True,
sign_edge_graph=False,
use_cache=USE_CACHE,
)
    if all(data["weight"] >= MIN_WEIGHT for _, _, data in dir_graph.edges(data=True)):
        logger.info("Edge belief weights OK")
    else:
        logger.warning(f"Edge weights below {MIN_WEIGHT} detected, resetting to {MIN_WEIGHT}")
        # Reset unsigned graph edge weights
        for _, _, data in tqdm(dir_graph.edges(data=True), desc="Resetting edge weights"):
            if data["weight"] < MIN_WEIGHT:
                data["weight"] = MIN_WEIGHT
        # Reset signed node graph edge weights
        for _, _, data in sign_node_graph.edges(data=True):
            if data["weight"] < MIN_WEIGHT:
                data["weight"] = MIN_WEIGHT
bio_ontology.initialize()
# Get a Trie for autocomplete
logger.info("Loading Trie structure with unsigned graph nodes")
nodes_trie = NodesTrie.from_node_names(graph=dir_graph)
nsid_trie = NodesTrie.from_node_ns_id(graph=dir_graph)
# Set numbers for server status
STATUS.unsigned_nodes = len(dir_graph.nodes)
STATUS.unsigned_edges = len(dir_graph.edges)
STATUS.signed_nodes = len(sign_node_graph.nodes)
STATUS.signed_edges = len(sign_node_graph.edges)
dt = dir_graph.graph.get("date")
if dt:
STATUS.graph_date = date.fromisoformat(dt)
# Setup search API
logger.info("Setting up IndraNetworkSearchAPI with signed and unsigned graphs")
network_search_api = IndraNetworkSearchAPI(unsigned_graph=dir_graph, signed_node_graph=sign_node_graph)
logger.info("Service is available")
STATUS.status = "available"
HEALTH.status = "available"