Exact Marginal Estimation via Junction Trees
A short introduction to the mathematical perspective on exact marginal inference on graphs via the junction tree algorithm
- Authors
- Name
- Dr. Dmitrij Sitenko
- Mathematical Image Processing at Ruprecht Karls University Heidelberg
Junction trees, also known as clique trees, are a data structure used in probabilistic graphical models for the exact inference of the marginals of a given joint probability distribution.
We start with a simple graphical model consisting of four nodes that form a tree.
Knowing the joint distribution, our goal is to compute the marginals. A simple approach is to marginalize over the complementary set of nodes,
which requires one summation per joint configuration of the remaining nodes if we assume each node to be a categorical random variable taking values in a finite set. For larger graphs, however, this procedure scales exponentially with the number of nodes in the graph, so that the number of summations needed for each marginal quickly exceeds the number of atoms in the universe. A way out of this hopeless approach is given by the underlying tree structure of the graph, with respect to which the distribution can be rewritten as
Because one node separates the nodes connected to it, conditional independence implies that the joint probability distribution factorizes into the product of conditional distributions
So far nothing is gained, because starting from the right-hand side we still have to sum over the remaining nodes. However, if we knew the edge marginals, we could resolve this problem by summing over a single node only. This is the essential reason why estimating the marginals of a large joint probability distribution attracts so much attention in the scientific community.
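To make the cost of the naive approach concrete, the following is a minimal sketch of brute-force marginalization over a full joint table. The function name `naiveMarginal` and the flat table layout are assumptions made for this illustration and do not come from the article.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive marginal of one variable from a full joint table (illustrative sketch).
// Assumed layout (not from the article): n variables with K states each, and
// joint[idx] holds p(x_0, ..., x_{n-1}) with idx = x_0 + K*x_1 + K^2*x_2 + ...
std::vector<double> naiveMarginal(const std::vector<double>& joint,
                                  std::size_t n, std::size_t K,
                                  std::size_t target) {
    std::size_t total = 1;
    for (std::size_t i = 0; i < n; ++i) total *= K;  // K^n table entries
    assert(joint.size() == total && target < n);

    std::size_t stride = 1;
    for (std::size_t i = 0; i < target; ++i) stride *= K;  // K^target

    std::vector<double> marginal(K, 0.0);
    // Every one of the K^n entries is visited once: this is the exponential cost.
    for (std::size_t idx = 0; idx < total; ++idx) {
        const std::size_t state = (idx / stride) % K;  // value of x_target
        marginal[state] += joint[idx];
    }
    return marginal;
}
```

Calling `naiveMarginal(joint, n, K, i)` returns the marginal of node `i`. The loop over all `K^n` entries is exactly the part that message passing on a tree avoids.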
Following these steps, we have now reduced the marginalization problem to the task of inferring edge marginals. But how do we arrive at the right-hand side if we have more complicated graphs? Before answering this question, we state an alternative factorization of
Now, even if we do not have access to the true marginals, we still have the global distribution on the left-hand side and the factorization on the right-hand side, which we can use to model our distribution
which agrees with the previous factorization for a particular choice of the potentials. Now, given any factorization with respect to potential functions, how can we efficiently access the marginal at a given node? To get an intuition, we consider the simple scenario of a chain consisting of three nodes.
We want updates of the local potentials that in the limit satisfy the constraint
Now assume that this local consistency condition is not satisfied, that is
At the next iteration we want to fix this by first updating the separator node marginal
and then normalizing the edge potential
For the above steps to be admissible, two conditions must hold:
The factorization with respect to the updated potentials does not affect the joint probability, i.e.
The local consistency condition at the separator node:
which is still violated by assumption. However, what we have achieved is the marginalization condition at the edge, as
This process of propagating the marginal from one leaf node over the separator node to the other leaf node is commonly called the forward pass, which leads to the true marginal distribution on that edge. By construction, we can do the same in the reverse direction by repeating the steps
and then updating the edge potential
which defines the backward pass
We next check the three conditions again:
Joint probability distribution:
Local consistency:
Correct edge marginals:
and consequently we obtain a new factorization of the joint probability with
We can generalize the above procedure from a chain graph to a simple tree graph by including an additional node and adding the missing components to the factorization in the forward pass, which yields the correct edge potential via an additional marginalization over the new node.
Simply put, the above update corresponds to a simple forward pass in which a set of messages now arrives at the separating node from the edges connected to it. The notion of message passing updates comes from the fact that we can interpret the current marginals at the separating node as messages sent from the marginalized node over the separator to the leaf node. Repeating the steps of the backward pass, now with the additional edge potential, it follows
It can easily be checked that all conditions hold and that we get the correct marginals from the backward pass. The following graphic illustrates the steps
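The forward and backward passes on the three-node chain can be sketched as follows. This is an illustration under stated assumptions, not the article's own implementation: the node names `a`, `s`, `b`, the struct `ChainPotentials` and both function names are hypothetical, and the update rule is the usual Hugin-style rescaling of an edge potential by the ratio of the new to the old separator potential.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Table = std::vector<std::vector<double>>;

// Hypothetical container for the chain a - s - b (names are assumptions).
// The modeled joint is proportional to psiAS[a][s] * psiSB[s][b] / phiS[s].
struct ChainPotentials {
    Table psiAS;               // psiAS[a][s]: potential on edge (a, s)
    Table psiSB;               // psiSB[s][b]: potential on edge (s, b)
    std::vector<double> phiS;  // phiS[s]: potential on the separator node s
};

// Forward pass: marginalize a out of psiAS to update the separator,
// then rescale psiSB so that the joint stays unchanged.
void forwardPass(ChainPotentials& c) {
    assert(c.psiSB.size() == c.phiS.size());
    for (std::size_t s = 0; s < c.phiS.size(); ++s) {
        double updated = 0.0;
        for (const auto& row : c.psiAS) updated += row[s];
        const double ratio = (c.phiS[s] > 0.0) ? updated / c.phiS[s] : 0.0;
        for (double& value : c.psiSB[s]) value *= ratio;
        c.phiS[s] = updated;
    }
}

// Backward pass: marginalize b out of psiSB to update the separator,
// then rescale psiAS by the ratio of new to old separator potential.
void backwardPass(ChainPotentials& c) {
    assert(c.psiSB.size() == c.phiS.size());
    for (std::size_t s = 0; s < c.phiS.size(); ++s) {
        double updated = 0.0;
        for (double value : c.psiSB[s]) updated += value;
        const double ratio = (c.phiS[s] > 0.0) ? updated / c.phiS[s] : 0.0;
        for (auto& row : c.psiAS) row[s] *= ratio;
        c.phiS[s] = updated;
    }
}
```

Starting from a separator potential of all ones, calling `forwardPass` and then `backwardPass` leaves both edge potentials equal to the unnormalized edge marginals. Dividing each table by its total sum then gives the normalized marginals.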
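More generally, on a tree graph $G = (V, E)$ with $n = |V|$ nodes, each tree-structured distribution factorizes via local marginals

$$p(x) = \prod_{i \in V} p_i(x_i) \prod_{(i,j) \in E} \frac{p_{ij}(x_i, x_j)}{p_i(x_i)\, p_j(x_j)}$$

where:
- $p_i(x_i)$ are the marginal probabilities at the nodes $i \in V$ of the tree graph $G$,
- $p_{ij}(x_i, x_j)$ are the pairwise joint probabilities on the edges $(i,j) \in E$,
- the fraction $\frac{p_{ij}(x_i, x_j)}{p_i(x_i)\, p_j(x_j)}$ is the desired representation of the edge factors in terms of marginals.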
This can be seen as follows
First, each node on a tree is conditionally independent of all other nodes given its neighbors, i.e. the joint distribution satisfies the local Markov property:
where the conditioning is on the set of neighboring nodes of the node in question. The joint distribution can be expanded using the chain rule for probabilities
where each vertex is conditioned on its parent with respect to some arbitrarily chosen root of the tree.
Second, for each vertex, the conditional distribution given its parent can be expressed as the joint marginal of the vertex and its parent divided by the marginal of the parent,
which after insertion into the joint probability yields the desired factorization
where the last equality can be shown by induction over the number of nodes.
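For reference, the steps of this argument can be written out as follows. This is a sketch in assumed notation chosen for illustration: $r$ denotes the arbitrarily chosen root, $\pi(i)$ the parent of vertex $i$, and $p_i$, $p_{ij}$ the node and edge marginals.

```latex
\begin{align*}
p(x) &= p_r(x_r) \prod_{i \in V \setminus \{r\}} p(x_i \mid x_{\pi(i)}) \\
     &= p_r(x_r) \prod_{i \in V \setminus \{r\}}
        \frac{p_{i\pi(i)}(x_i, x_{\pi(i)})}{p_{\pi(i)}(x_{\pi(i)})} \\
     &= \prod_{i \in V} p_i(x_i) \prod_{(i,j) \in E}
        \frac{p_{ij}(x_i, x_j)}{p_i(x_i)\, p_j(x_j)}
\end{align*}
```

The last step multiplies and divides each factor by $p_i(x_i)$ and uses the fact that every edge of the tree joins exactly one non-root vertex to its parent.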
As a next, more sophisticated example, we consider the following graphical model with cycles on the integer grid
with an underlying joint distribution. We again want to compute the marginals. A naive approach to accessing the marginal of a single node is to sum over all other nodes, which results in
which scales exponentially with the number of nodes. However, we can follow the same idea and determine the marginals by performing message passing on a tree graph. Due to the presence of cycles, we first introduce a slight generalization of tree graphs, the so-called junction trees.
A junction tree consists of the maximal cliques (maximal subsets of fully connected nodes) of a triangulated graph, i.e. a graph in which every cycle of length four or more has a chord. The edges are defined in terms of the separator sets, i.e. the intersections of adjacent cliques. In addition, the tree must satisfy the running intersection property: for any two maximal cliques and any node contained in both, all cliques and separator sets on the path between the two cliques include that node. The following figure illustrates a valid junction tree for the above grid graph.
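The running intersection property can be checked mechanically. The sketch below uses an equivalent formulation: for every graph node, the cliques containing it must form a connected subtree. The function name `hasRunningIntersection` and the input format (cliques as sets, tree edges as index pairs) are assumptions made for this illustration, and the function assumes the given edges already form a tree.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <set>
#include <utility>
#include <vector>

// Checks the running intersection property of a candidate junction tree.
// cliques[i] is the node set of clique i; treeEdges connect clique indices.
// Assumption: treeEdges form a tree over the cliques (not verified here).
bool hasRunningIntersection(const std::vector<std::set<int>>& cliques,
                            const std::vector<std::pair<int, int>>& treeEdges) {
    const std::size_t n = cliques.size();
    std::vector<std::vector<std::size_t>> adj(n);
    for (const auto& e : treeEdges) {
        assert(e.first >= 0 && e.second >= 0);
        const auto a = static_cast<std::size_t>(e.first);
        const auto b = static_cast<std::size_t>(e.second);
        assert(a < n && b < n);
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::set<int> variables;
    for (const auto& c : cliques) variables.insert(c.begin(), c.end());

    for (int v : variables) {
        // Count the cliques containing v and remember the first one.
        std::size_t containing = 0;
        std::size_t start = n;
        for (std::size_t i = 0; i < n; ++i) {
            if (cliques[i].count(v)) {
                ++containing;
                if (start == n) start = i;
            }
        }
        // Breadth-first search restricted to cliques that contain v.
        std::vector<bool> seen(n, false);
        std::queue<std::size_t> frontier;
        frontier.push(start);
        seen[start] = true;
        std::size_t reached = 0;
        while (!frontier.empty()) {
            const std::size_t cur = frontier.front();
            frontier.pop();
            ++reached;
            for (std::size_t nb : adj[cur]) {
                if (!seen[nb] && cliques[nb].count(v)) {
                    seen[nb] = true;
                    frontier.push(nb);
                }
            }
        }
        // If some clique containing v was not reached, the property fails.
        if (reached != containing) return false;
    }
    return true;
}
```

For example, the cliques {0,1,2}, {1,2,3}, {2,3,4}, {3,4,5} connected in this order as a chain pass the check, while attaching {0,1,2} to {3,4,5} directly does not, because node 1 would be missing on the path to {1,2,3}.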
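#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>
using namespace std;

struct Node {
    int id;
    vector<int> neighbors;
};

// A clique is stored as an ordered set so that subset tests and output are deterministic.
struct Clique {
    set<int> nodes;
};

class Graph {
    int V;
    vector<Node> nodes;
public:
    explicit Graph(int V);
    void addEdge(int v, int w);
    vector<Clique> triangulate();
    void printJunctionTree(const vector<Clique>& cliques);
};

Graph::Graph(int V) {
    this->V = V;
    nodes.resize(V);
    for (int i = 0; i < V; ++i) {
        nodes[i].id = i;
    }
}

void Graph::addEdge(int v, int w) {
    nodes[v].neighbors.push_back(w);
    nodes[w].neighbors.push_back(v);
}

// Triangulation by node elimination using the minimum fill-in heuristic.
// Returns the maximal cliques of the triangulated graph.
vector<Clique> Graph::triangulate() {
    // Work on a copy of the adjacency structure, because elimination adds fill-in edges.
    vector<set<int>> adj(V);
    for (int i = 0; i < V; ++i)
        adj[i].insert(nodes[i].neighbors.begin(), nodes[i].neighbors.end());

    vector<bool> eliminated(V, false);
    vector<Clique> cliques;

    for (int step = 0; step < V; ++step) {
        // Pick the remaining node whose elimination adds the fewest fill-in edges.
        int best = -1;
        int bestFill = 0;
        for (int v = 0; v < V; ++v) {
            if (eliminated[v]) continue;
            int fill = 0;
            for (auto a = adj[v].begin(); a != adj[v].end(); ++a)
                for (auto b = next(a); b != adj[v].end(); ++b)
                    if (!adj[*a].count(*b)) ++fill;
            if (best == -1 || fill < bestFill) {
                best = v;
                bestFill = fill;
            }
        }

        // The node and its remaining neighbors form a clique once the fill-in edges are added.
        Clique clique;
        clique.nodes = adj[best];
        clique.nodes.insert(best);

        // Connect all remaining neighbors pairwise (fill-in edges).
        for (int a : adj[best])
            for (int b : adj[best])
                if (a != b) adj[a].insert(b);

        // Remove the eliminated node from the remaining graph.
        for (int a : adj[best]) adj[a].erase(best);
        adj[best].clear();
        eliminated[best] = true;

        // Keep the clique only if it is not contained in an already stored clique.
        bool maximal = true;
        for (const Clique& c : cliques) {
            if (includes(c.nodes.begin(), c.nodes.end(),
                         clique.nodes.begin(), clique.nodes.end())) {
                maximal = false;
                break;
            }
        }
        if (maximal) cliques.push_back(clique);
    }
    return cliques;
}

// Prints the cliques that form the nodes of the junction tree.
void Graph::printJunctionTree(const vector<Clique>& cliques) {
    cout << "Junction Tree:" << endl;
    for (size_t i = 0; i < cliques.size(); ++i) {
        cout << "Clique " << i << ": ";
        for (int node : cliques[i].nodes) {
            cout << node << " ";
        }
        cout << endl;
    }
}

int main() {
    Graph graph(6);
    graph.addEdge(0, 1);
    graph.addEdge(0, 2);
    graph.addEdge(1, 2);
    graph.addEdge(1, 3);
    graph.addEdge(2, 4);
    graph.addEdge(3, 4);
    graph.addEdge(3, 5);
    graph.addEdge(4, 5);
    vector<Clique> cliques = graph.triangulate();
    graph.printJunctionTree(cliques);
    return 0;
}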
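The program above lists cliques but does not connect them into a tree. One standard way to obtain the edges, sketched here under the assumption that the cliques are the maximal cliques of a triangulated graph, is a maximum-weight spanning tree over the cliques with the separator size as edge weight. The function names `separatorSize` and `junctionTreeEdges` are hypothetical and not part of the program above.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Number of nodes shared by two cliques, i.e. the size of their separator set.
std::size_t separatorSize(const std::set<int>& a, const std::set<int>& b) {
    std::size_t shared = 0;
    for (int v : a) shared += b.count(v);
    return shared;
}

// Maximum-weight spanning tree over the cliques (Prim's algorithm).
// Returns pairs (i, j) of clique indices; assumes the input cliques are the
// maximal cliques of a triangulated graph.
std::vector<std::pair<int, int>> junctionTreeEdges(
        const std::vector<std::set<int>>& cliques) {
    const int n = static_cast<int>(cliques.size());
    std::vector<std::pair<int, int>> edges;
    if (n == 0) return edges;

    std::vector<bool> inTree(n, false);
    inTree[0] = true;  // grow the tree starting from the first clique
    for (int added = 1; added < n; ++added) {
        int bestFrom = -1;
        int bestTo = -1;
        std::size_t bestWeight = 0;
        for (int i = 0; i < n; ++i) {
            if (!inTree[i]) continue;
            for (int j = 0; j < n; ++j) {
                if (inTree[j]) continue;
                const std::size_t w = separatorSize(cliques[i], cliques[j]);
                // Prefer the largest separator; ties keep the first candidate.
                if (bestTo == -1 || w > bestWeight) {
                    bestFrom = i;
                    bestTo = j;
                    bestWeight = w;
                }
            }
        }
        assert(bestTo != -1);
        inTree[bestTo] = true;
        edges.emplace_back(bestFrom, bestTo);
    }
    return edges;
}
```

Each returned pair is one edge of the junction tree, and the intersection of the two cliques it joins is the corresponding separator set.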