## Implementing special sets in Java: Partition a.k.a. DisjointSet

Recently we implemented an algorithm for a customer product that involved complex data structures for calculating conflicts in concurrent data updates. A performance-critical part of the algorithm required checking whether two elements are in the same set and, additionally, calculating the union of sets of elements. While the first operation is fast with a HashSet (expected constant time), the latter is expensive (linear time). There is, however, a data structure called Disjoint Set or Union Find that delivers both operations in (almost\*) constant amortized time. Unfortunately, neither Java nor popular libraries such as Google Guava or Apache Commons ship with an implementation. However, implementing the data structure is straightforward. In this blog post, we will implement a DisjointSet.

In my opinion, the name DisjointSet is a bit weird, so I named the respective classes differently. From a mathematical point of view, a DisjointSet is a partition of a set: it divides the set into subsets such that every element is contained in exactly one subset. Accordingly, a Partition contains a set of PartitionSubSets. The Partition interface has three methods:

```
public interface Partition {
    PartitionSubSet makeSet(Object object);

    PartitionSubSet find(Object object);

    PartitionSubSet union(PartitionSubSet subSet1, PartitionSubSet subSet2)
            throws UnkownPartitionSubSetException;
}
```

makeSet will create a PartitionSubSet containing only the given object if the object was not previously contained in a subset of the Partition. If the object was already part of the Partition, the PartitionSubSet that contains it will be returned instead.

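find will return the representative of the set that contains the given object. The representative identifies the set, so it can be used to find out whether two objects a and b are in the same set. Note that the representative of a set may change when sets are merged, so results of find should not be cached across calls to union.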

```
boolean containedInSameSet(Object a, Object b) {
    return find(a) == find(b);
}
```

Union will create a union of the two sets and return the resulting set as PartitionSubSet.
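The interface refers to two types that this post does not show: PartitionSubSet and UnkownPartitionSubSetException. A minimal sketch of how they could look is given below. Both definitions are assumptions: PartitionSubSet is treated as an empty marker interface, and the exception is assumed to be a checked exception, which fits the throws clause but is not confirmed by the post. The message text is made up for illustration.

```java
// Assumption: an opaque handle for a subset. Callers only pass it back to
// the Partition or compare two handles by reference.
interface PartitionSubSet {
}

// Assumption: a checked exception, thrown when union() receives a subset
// handle that was created by a different Partition implementation.
// The class name keeps the spelling used in the post.
class UnkownPartitionSubSetException extends Exception {
    UnkownPartitionSubSetException() {
        super("The given PartitionSubSet is not known to this Partition");
    }
}
```

Keeping PartitionSubSet free of methods means clients cannot depend on how a subset is represented internally.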

DisjointSets are implemented as a forest (set of trees). Each tree represents a set where each entry in the tree (leaf, inner node and root) is an element in the set. Each element in the tree points to its parent element or to null if it is the root. Our implementation will be called TreePartition and implement the Partition interface. To represent the tree nodes we create a class PartitionElement. PartitionElement will wrap objects (values) and maintain a reference to its parent PartitionElement or null if the PartitionElement is the root of the tree. In addition, it contains an integer field defining the rank of the PartitionElement which we will need later on to implement the union method.
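To make the forest idea concrete before adding any optimizations, here is a minimal sketch of a parent-pointer forest. The class NaiveForest and its method names are hypothetical and not part of the implementation in this post. It stores only a parent per element (null for a root), uses String values for brevity, and performs neither rank bookkeeping nor path rewriting.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical illustration: a forest stored as child -> parent links.
class NaiveForest {
    // Maps each element to its parent; a root maps to null.
    private final Map<String, String> parent = new HashMap<>();

    // Registers an element as a one-element tree if it is not known yet.
    void add(String element) {
        if (!parent.containsKey(element)) {
            parent.put(element, null);
        }
    }

    // Walks the parent links until an element without a parent is reached.
    String findRoot(String element) {
        String current = element;
        while (parent.get(current) != null) {
            current = parent.get(current);
        }
        return current;
    }

    // Merges two trees by hanging the root of a's tree below the root of b's tree.
    void union(String a, String b) {
        String rootA = findRoot(a);
        String rootB = findRoot(b);
        if (!rootA.equals(rootB)) {
            parent.put(rootA, rootB);
        }
    }
}
```

Without balancing, repeated unions can produce a tree that is a single long chain, in which case findRoot has to visit every element. The rank field and the path rewriting described in the rest of the post address exactly this.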

```
public class PartitionElement implements PartitionSubSet {
    // The wrapped object.
    private Object value;
    // Upper bound on the depth of the tree below this element.
    private int rank;
    // Parent element, or null if this element is the root of its tree.
    private PartitionElement parent;

    public PartitionElement(Object value) {
        this.value = value;
    }

    public Object getValue() {
        return value;
    }

    public int getRank() {
        return rank;
    }

    public void incRank() {
        this.rank += 1;
    }

    public PartitionElement getParent() {
        return parent;
    }

    public void setParent(PartitionElement parent) {
        this.parent = parent;
    }
}
```

The implementation of TreePartition is as follows. I will explain it method by method in the remainder of this post.

```
import java.util.LinkedHashMap;

public class TreePartition implements Partition {
    // Maps each value to the tree node that wraps it.
    private final LinkedHashMap<Object, PartitionElement> objectToPartitionElementMap;

    public TreePartition() {
        objectToPartitionElementMap = new LinkedHashMap<Object, PartitionElement>();
    }

    public PartitionSubSet find(Object object) {
        return findPartitionElement(getElement(object));
    }

    // Returns the root of the tree containing partitionElement and re-points
    // every element on the path directly at that root (path compression).
    private PartitionElement findPartitionElement(PartitionElement partitionElement) {
        PartitionElement parent = partitionElement.getParent();
        if (parent == null) {
            return partitionElement;
        }
        PartitionElement root = findPartitionElement(parent);
        partitionElement.setParent(root);
        return root;
    }

    // Returns the element wrapping the object, creating it on first use.
    private PartitionElement getElement(Object object) {
        PartitionElement partitionElement = objectToPartitionElementMap.get(object);
        if (partitionElement == null) {
            partitionElement = new PartitionElement(object);
            objectToPartitionElementMap.put(object, partitionElement);
        }
        return partitionElement;
    }

    public PartitionSubSet makeSet(Object object) {
        return findPartitionElement(getElement(object));
    }

    public PartitionSubSet union(PartitionSubSet subSet1, PartitionSubSet subSet2)
            throws UnkownPartitionSubSetException {
        if (!(subSet1 instanceof PartitionElement)
                || !(subSet2 instanceof PartitionElement)) {
            throw new UnkownPartitionSubSetException();
        }
        // Resolve both arguments to their current roots, so that handles
        // obtained before an earlier union are still treated correctly.
        PartitionElement subSet1Root = findPartitionElement((PartitionElement) subSet1);
        PartitionElement subSet2Root = findPartitionElement((PartitionElement) subSet2);
        if (subSet1Root == subSet2Root) {
            // Already one set; making a root its own parent would create a cycle.
            return subSet1Root;
        }
        PartitionElement result;
        if (subSet1Root.getRank() > subSet2Root.getRank()) {
            subSet2Root.setParent(subSet1Root);
            result = subSet1Root;
        } else if (subSet1Root.getRank() < subSet2Root.getRank()) {
            subSet1Root.setParent(subSet2Root);
            result = subSet2Root;
        } else {
            // Equal ranks: pick subSet1Root as the new root and increase its rank.
            subSet1Root.incRank();
            subSet2Root.setParent(subSet1Root);
            result = subSet1Root;
        }
        return result;
    }
}
```

The TreePartition uses a LinkedHashMap to map objects (values) to their respective PartitionElement. The private method getElement() returns the PartitionElement for a given object and creates a new PartitionElement if there is none yet. The private method findPartitionElement determines the root of a given element by following the parent references until it reaches the root, whose parent reference is null. It employs an optimization called path compression, which helps to achieve the (almost) constant amortized time for a call of find. As the recursion returns from the root, findPartitionElement rewrites the parent reference of every traversed element to point directly at the root. This reduces the cost of future calls and, over time, pulls up leaves and inner nodes so that they become direct children of the root of the tree.
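Path compression can also be written without recursion. The following is a minimal sketch of an iterative two-pass variant, not the method used in this post. The Node class is a hypothetical stand-in for PartitionElement, reduced to the parent reference, and findRoot is a made-up name. The recursive version is fine as long as trees stay shallow; the iterative form simply does not depend on the call stack depth.

```java
// Hypothetical stand-in for PartitionElement, reduced to the parent link.
class Node {
    Node parent; // null if this node is the root of its tree
}

class PathCompression {
    // Returns the root of the tree containing start and re-points every
    // node on the path from start to the root directly at the root.
    static Node findRoot(Node start) {
        // First pass: follow the parent links up to the root.
        Node root = start;
        while (root.parent != null) {
            root = root.parent;
        }
        // Second pass: walk the same path again and rewrite each parent link.
        Node current = start;
        while (current != root) {
            Node next = current.parent;
            current.parent = root;
            current = next;
        }
        return root;
    }
}
```

After one call, every node that was on the path is a direct child of the root, so a later lookup for any of them takes a single step.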

The find method now only needs to get the PartitionElement for an object and pass it on to findPartitionElement().

The makeSet method is identical to find, since getElement() will create a PartitionElement for a new object.

The union method merges two trees by making the root of one tree the parent of the root of the other tree. To avoid unbalanced trees with a large depth compared to the number of contained elements, a rank is used to maintain an upper bound on the depth. A tree with a larger rank is considered to be deeper, and therefore the tree with the smaller rank is added as a subtree to the tree with the larger rank. If the ranks of the two trees are identical, the resulting tree has its rank increased by one. This optimization is called union-by-rank. Union-by-rank in combination with path compression is the reason for the (almost) constant amortized time complexity mentioned above. Without these optimizations, the worst-case time complexity of the find method would be linear.
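For comparison, both optimizations can be shown together in a compact array-based form. This is a sketch under different assumptions than the object-based TreePartition: elements are the integers 0 to size-1, and a root is marked by being its own parent rather than by a null reference. The class name IntDisjointSet and its methods are hypothetical.

```java
// Hypothetical array-based variant with union-by-rank and path compression.
class IntDisjointSet {
    private final int[] parent; // parent[i] == i marks a root
    private final int[] rank;   // upper bound on the depth below a root

    IntDisjointSet(int size) {
        parent = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i; // every element starts as its own one-element set
        }
    }

    // Returns the representative of x's set and compresses the path to it.
    int find(int x) {
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[x] != root) {
            int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Merges the sets of a and b and returns the representative of the result.
    int union(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA == rootB) {
            return rootA; // already in the same set
        }
        if (rank[rootA] < rank[rootB]) {
            int tmp = rootA; // make rootA the root with the larger rank
            rootA = rootB;
            rootB = tmp;
        }
        parent[rootB] = rootA;
        if (rank[rootA] == rank[rootB]) {
            rank[rootA]++; // equal ranks: the merged tree may be one level deeper
        }
        return rootA;
    }

    boolean sameSet(int a, int b) {
        return find(a) == find(b);
    }
}
```

A short usage example: after `union(0, 1)`, `union(2, 3)` and `union(1, 3)` on a set of six elements, `sameSet(0, 2)` is true while `sameSet(4, 5)` is still false.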

Wikipedia provides a pretty solid article with pointers to more information.

\* almost: more precisely, the amortized time per operation is in O(α(n)), where α(n) is the inverse of the Ackermann function A(n, n). α(n) is less than 5 for all practical values of n.