# -*- coding: utf-8 -*-
import logging
from typing import Iterable, Mapping, Set, Tuple
import itertools as itt
from networkx import DiGraph, Graph
from pybel import BELGraph
from pybel.constants import (
CAUSAL_DECREASE_RELATIONS, CAUSAL_INCREASE_RELATIONS, CORRELATIVE_RELATIONS,
NEGATIVE_CORRELATION, POSITIVE_CORRELATION, RELATION,
)
from pybel.dsl import BaseEntity
from pybel.struct.utils import update_node_helper
from ..selection import get_causal_subgraph
from ..summary.contradictions import relation_set_has_contradictions
# Public API of this module.
# NOTE(review): 'jens_transformation_alpha' and 'jens_transformation_beta' are exported here but are
# not defined in the visible part of this file (get_jens_unstable calls jens_transformation_alpha);
# confirm both are defined elsewhere in the module.
__all__ = [
    'get_contradiction_summary',
    'get_regulatory_pairs',
    'get_chaotic_pairs',
    'get_dampened_pairs',
    'get_correlation_graph',
    'get_correlation_triangles',
    'get_separate_unstable_correlation_triples',
    'get_mutually_unstable_correlation_triples',
    'get_triangles',
    'jens_transformation_alpha',
    'jens_transformation_beta',
    'get_jens_unstable',
    'get_increase_mismatch_triplets',
    'get_decrease_mismatch_triplets',
    'get_chaotic_triplets',
    'get_dampened_triplets',
    'summarize_stability',
]

log = logging.getLogger(__name__)

# Type aliases used in the signatures below.
NodePair = Tuple[BaseEntity, BaseEntity]  # two graph nodes
NodeTriple = Tuple[BaseEntity, BaseEntity, BaseEntity]  # three graph nodes
SetOfNodePairs = Set[NodePair]
SetOfNodeTriples = Set[NodeTriple]
def get_contradiction_summary(graph: BELGraph) -> Iterable[Tuple[BaseEntity, BaseEntity, Set[str]]]:
    """Yield (source node, target node, set of relations) for pairs with contradictory relations.

    Each directed (source, target) pair is examined once, even when the multigraph holds several
    parallel edges between the two nodes. All relations on those parallel edges are collected
    and checked together.

    :param graph: A BEL graph
    :return: An iterator of triples ``(source, target, relations)``, where ``relations`` is the set of
     relation strings between the two nodes and was flagged by ``relation_set_has_contradictions``
    """
    # graph.edges() can report the same (u, v) once per parallel edge; dedupe before inspecting
    for u, v in set(graph.edges()):
        relations = {data[RELATION] for data in graph[u][v].values()}
        if relation_set_has_contradictions(relations):
            yield u, v, relations
def get_regulatory_pairs(graph: BELGraph) -> Set[NodePair]:
    """Find node pairs ``(A, B)`` that regulate each other such that ``A -> B`` and ``B -| A``.

    Only the causal subgraph is searched. Pairs are directional: the first node owns the
    increasing edge, the second owns the decreasing edge pointing back.

    :return: A set of pairs of nodes with mutual causal edges
    """
    causal = get_causal_subgraph(graph)
    return {
        (source, target)
        for source, target, data in causal.edges(data=True)
        if data[RELATION] in CAUSAL_INCREASE_RELATIONS
        and causal.has_edge(target, source)
        and any(
            back[RELATION] in CAUSAL_DECREASE_RELATIONS
            for back in causal[target][source].values()
        )
    }
def get_chaotic_pairs(graph: BELGraph) -> SetOfNodePairs:
    """Find node pairs that increase each other such that ``A -> B`` and ``B -> A``.

    Only the causal subgraph is searched. Each pair is reported once, with its two nodes
    sorted by their string representation.

    :return: A set of pairs of nodes with mutual causal edges
    """
    causal = get_causal_subgraph(graph)
    return {
        tuple(sorted((source, target), key=str))
        for source, target, data in causal.edges(data=True)
        if data[RELATION] in CAUSAL_INCREASE_RELATIONS
        and causal.has_edge(target, source)
        and any(
            back[RELATION] in CAUSAL_INCREASE_RELATIONS
            for back in causal[target][source].values()
        )
    }
def get_dampened_pairs(graph: BELGraph) -> SetOfNodePairs:
    """Find node pairs that decrease each other such that ``A -| B`` and ``B -| A``.

    Only the causal subgraph is searched. Each pair is reported once, with its two nodes
    sorted by their string representation.

    :return: A set of pairs of nodes with mutual causal edges
    """
    causal = get_causal_subgraph(graph)
    return {
        tuple(sorted((source, target), key=str))
        for source, target, data in causal.edges(data=True)
        if data[RELATION] in CAUSAL_DECREASE_RELATIONS
        and causal.has_edge(target, source)
        and any(
            back[RELATION] in CAUSAL_DECREASE_RELATIONS
            for back in causal[target][source].values()
        )
    }
def get_correlation_graph(graph: BELGraph) -> Graph:
    """Extract an undirected graph holding only the correlative relationships of ``graph``.

    Every correlative relation found between two nodes becomes a boolean flag (``True``) on the
    single undirected edge joining them. A low-level (5) log message is emitted whenever an
    already-existing edge gains a relation it did not have yet.

    :param graph: A BEL graph
    :return: An undirected graph whose edge attributes are the correlative relation names
    """
    correlations = Graph()
    for source, target, data in graph.edges(data=True):
        relation = data[RELATION]
        if relation not in CORRELATIVE_RELATIONS:
            continue
        if correlations.has_edge(source, target) and relation not in correlations[source][target]:
            log.log(5, 'broken correlation relation for %s, %s', source, target)
        # On an undirected Graph the (u, v) and (v, u) attribute dicts are one object, and
        # add_edge merges the new attribute into an existing edge's data.
        correlations.add_edge(source, target, **{relation: True})
    return correlations
def get_correlation_triangles(graph: Graph) -> SetOfNodeTriples:
    """Return every triangle (three pairwise-adjacent nodes) of an undirected graph.

    Intended for the output of :func:`get_correlation_graph`. Each triangle is reported once,
    with its three nodes sorted by their string representation. Self-loops are ignored, so a
    node never appears twice in the same triple.

    :param graph: An undirected graph
    :return: A set of sorted triples of distinct, pairwise-adjacent nodes
    """
    return {
        tuple(sorted([n, u, v], key=str))
        for n in graph
        for u, v in itt.combinations(graph[n], 2)
        # a self-loop puts ``n`` in its own neighborhood; skip it to avoid degenerate triples
        if n != u and n != v and graph.has_edge(u, v)
    }
def get_triangles(graph: DiGraph) -> SetOfNodeTriples:
    """Get a set of triples representing the 3-cycles from a directional graph.

    A 3-cycle is three distinct nodes with ``a -> b``, ``b -> c`` and ``c -> a``. Each 3-cycle
    is returned once, with its nodes sorted by their string representation. Self-loops never
    contribute, so a node cannot appear twice in the same triple.

    :param graph: A directed graph
    :return: A set of sorted triples of distinct nodes
    """
    return {
        tuple(sorted([a, b, c], key=str))
        for a, b in graph.edges()
        if a != b  # a self-loop cannot be one leg of a 3-cycle
        for c in graph.successors(b)
        if c != a and c != b and graph.has_edge(c, a)
    }
def get_separate_unstable_correlation_triples(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield all triples of nodes A, B, C such that ``A pos B``, ``A pos C``, and ``B neg C``.

    The triangles of the correlation graph are inspected one by one. For each orientation in
    which one node is positively correlated with both others, while those two are negatively
    correlated, a triple is yielded with that "hub" node first.

    :return: An iterator over triples of unstable graphs, where the second two are negatively correlated
    """
    correlation_graph = get_correlation_graph(graph)
    for first, second, third in get_correlation_triangles(correlation_graph):
        first_second = correlation_graph[first][second]
        second_third = correlation_graph[second][third]
        first_third = correlation_graph[first][third]

        pos_first_second = POSITIVE_CORRELATION in first_second
        pos_second_third = POSITIVE_CORRELATION in second_third
        pos_first_third = POSITIVE_CORRELATION in first_third

        # hub = second
        if pos_first_second and pos_second_third and NEGATIVE_CORRELATION in first_third:
            yield second, first, third
        # hub = first
        if pos_first_second and NEGATIVE_CORRELATION in second_third and pos_first_third:
            yield first, second, third
        # hub = third
        if NEGATIVE_CORRELATION in first_second and pos_second_third and pos_first_third:
            yield third, first, second
def get_mutually_unstable_correlation_triples(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield triples of nodes (A, B, C) such that ``A neg B``, ``B neg C``, and ``C neg A``.

    Every triangle of the correlation graph whose three sides all carry a negative correlation
    is yielded, in the sorted node order produced by the triangle finder.
    """
    correlation_graph = get_correlation_graph(graph)
    for first, second, third in get_correlation_triangles(correlation_graph):
        sides = (
            correlation_graph[first][second],
            correlation_graph[second][third],
            correlation_graph[first][third],
        )
        if not any(NEGATIVE_CORRELATION not in side for side in sides):
            yield first, second, third
def get_jens_unstable(graph: BELGraph) -> Iterable[NodeTriple]:
    """Find triples of nodes (A, B, C) where ``A -> B``, ``A -| C``, and ``C positiveCorrelation A``.

    The graph is first rewritten with ``jens_transformation_alpha``. The returned triples are the
    3-cycles of that rewritten graph as found by :func:`get_triangles`, i.e. a set of node triples
    sorted by string representation.

    NOTE(review): for a cycle through ``A -> B`` and ``A -| C`` the closing correlation would
    presumably be between B and C; confirm against ``jens_transformation_alpha``.
    """
    transformed = jens_transformation_alpha(graph)
    cycles = get_triangles(transformed)
    return cycles
def _get_mismatch_triplets_helper(graph: BELGraph, relation_set: Set[str]) -> Iterable[NodeTriple]:
    """Yield ``(node, a, b)`` where ``node`` reaches ``a`` and ``b`` by ``relation_set`` edges and
    ``a`` and ``b`` are negatively correlated.

    The negative correlation is accepted on an edge in either direction (``a -> b`` or
    ``b -> a``). The children are a set, so the order in which ``combinations`` pairs them is
    arbitrary; checking both directions keeps the result independent of that order.

    :param graph: A BEL graph
    :param relation_set: Relations an outgoing edge of ``node`` must have to count its target as a child
    :return: An iterator of triples ``(node, child, other child)``
    """
    for node in graph:
        children = {
            target
            for _, target, data in graph.out_edges(node, data=True)
            if data[RELATION] in relation_set
        }
        for a, b in itt.combinations(children, 2):
            if any(
                data[RELATION] == NEGATIVE_CORRELATION
                for u, v in ((a, b), (b, a))
                if graph.has_edge(u, v)
                for data in graph[u][v].values()
            ):
                yield node, a, b
def get_increase_mismatch_triplets(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield triples of nodes (A, B, C) where ``A -> B``, ``A -> C``, and ``B negativeCorrelation C``.

    A increases both B and C, yet B and C are negatively correlated with each other.

    :param graph: A BEL graph
    :return: An iterator of triples ``(A, B, C)``
    """
    return _get_mismatch_triplets_helper(graph, CAUSAL_INCREASE_RELATIONS)
def get_decrease_mismatch_triplets(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield triples of nodes (A, B, C) where ``A -| B``, ``A -| C``, and ``B negativeCorrelation C``.

    A decreases both B and C, yet B and C are negatively correlated with each other.

    :param graph: A BEL graph
    :return: An iterator of triples ``(A, B, C)``
    """
    return _get_mismatch_triplets_helper(graph, CAUSAL_DECREASE_RELATIONS)
def _get_disregulated_triplets_helper(graph: BELGraph, relation_set: Set[str]) -> Iterable[NodeTriple]:
    """Yield the 3-cycles built only from edges whose relation is in ``relation_set``.

    :param graph: A BEL graph
    :param relation_set: Relations an edge must have to be kept
    :return: An iterator of sorted triples of three distinct nodes
    """
    result = DiGraph()
    for u, v, d in graph.edges(data=True):
        if d[RELATION] in relation_set:
            result.add_edge(u, v)
    # presumably copies node data from ``graph`` onto the filtered graph — see update_node_helper
    update_node_helper(graph, result)
    for a, b, c in get_triangles(result):
        # A self-loop combined with a 2-cycle can surface a triple with a repeated node
        # (e.g. ``(a, b, b)``), not only ``(a, a, a)``; only three distinct nodes form a 3-cycle.
        if len({a, b, c}) < 3:
            continue
        yield a, b, c
def get_chaotic_triplets(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield triples of nodes (A, B, C) that increase each other in a cycle, such as when
    ``A -> B``, ``B -> C``, and ``C -> A``.

    :param graph: A BEL graph
    :return: An iterator of sorted node triples forming a cycle of causal increases
    """
    increase_relations = CAUSAL_INCREASE_RELATIONS
    triples = _get_disregulated_triplets_helper(graph, increase_relations)
    return triples
def get_dampened_triplets(graph: BELGraph) -> Iterable[NodeTriple]:
    """Yield triples of nodes (A, B, C) that decrease each other in a cycle, such as when
    ``A -| B``, ``B -| C``, and ``C -| A``.

    :param graph: A BEL graph
    :return: An iterator of sorted node triples forming a cycle of causal decreases
    """
    decrease_relations = CAUSAL_DECREASE_RELATIONS
    triples = _get_disregulated_triplets_helper(graph, decrease_relations)
    return triples
def summarize_stability(graph: BELGraph) -> Mapping[str, int]:
    """Count every kind of stability motif found in a graph.

    All finders are invoked first, in a fixed order, and only afterwards are their results
    counted. Finders that return lazy iterators are consumed during counting.

    :param graph: A BEL graph
    :return: An insertion-ordered mapping from a human-readable motif label to its count
    """
    findings = [
        ('Regulatory Pairs', get_regulatory_pairs(graph)),
        ('Chaotic Pairs', get_chaotic_pairs(graph)),
        ('Dampened Pairs', get_dampened_pairs(graph)),
        ('Contradictory Pairs', get_contradiction_summary(graph)),
        ('Separately Unstable Triples', get_separate_unstable_correlation_triples(graph)),
        ('Mutually Unstable Triples', get_mutually_unstable_correlation_triples(graph)),
        ('Jens Unstable Triples', get_jens_unstable(graph)),
        ('Increase Mismatch Triples', get_increase_mismatch_triplets(graph)),
        ('Decrease Mismatch Triples', get_decrease_mismatch_triplets(graph)),
        ('Chaotic Triples', get_chaotic_triplets(graph)),
        ('Dampened Triples', get_dampened_triplets(graph)),
    ]
    return {label: _count_or_len(found) for label, found in findings}
def _count_or_len(it):
return sum(1 for _ in it)