Kundu and Tree | Disjoint Set

Task(hard)

Kundu is a true tree lover. A tree is a connected graph with N vertices and N-1 edges. Today, when he got a tree, he colored each edge either red (r) or black (b). He wants to know how many triplets (a, b, c) of vertices there are such that there is at least one red edge on each of the three paths: from vertex a to b, from vertex b to c, and from vertex c to a. Note that (a, b, c), (b, a, c), and all other permutations are considered the same triplet. If the answer is greater than 10^9 + 7, print the answer modulo (%) 10^9 + 7.

Input Format
The first line contains an integer N, i.e., the number of vertices in tree.
The next N-1 lines represent edges: 2 space separated integers denoting an edge followed by a color of the edge. A color of an edge is denoted by a small letter of English alphabet, and it can be either red(r) or black(b).

Output Format
Print a single number i.e. the number of triplets.

Constraints

1 ≤ N ≤ 10^5
Nodes are numbered from 1 to N.

Sample Output

4

SOLUTION 1

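class DisjointSet:
    """One node of a disjoint-set forest; a root stores its component size."""

    def __init__(self):
        self.parent = self
        self.size = 1

    def findParent(self):
        # Path compression: point this node directly at its root.
        if self.parent is not self:
            self.parent = self.parent.findParent()
        return self.parent

    def union(self, other):
        root = self.findParent()
        other_root = other.findParent()
        if root is other_root:
            return
        # Union by size: attach the smaller tree under the larger one.
        if root.size > other_root.size:
            other_root.parent = root
            root.size += other_root.size
        else:
            root.parent = other_root
            other_root.size += root.size

def nc2(n):
    # Number of ways to choose 2 items out of n, in exact integer arithmetic.
    if n < 2:
        return 0
    return n * (n - 1) // 2

def nc3(n):
    # Number of ways to choose 3 items out of n, in exact integer arithmetic.
    if n < 3:
        return 0
    return n * (n - 1) * (n - 2) // 6

n = int(input())
components = [None] * n

for i in range(n - 1):
    a, b, c = input().split()
    a = int(a) - 1
    b = int(b) - 1
    if c == 'r':
        # Red edges separate components, so only black edges are merged.
        continue
    if components[a] is None:
        components[a] = DisjointSet()
    if components[b] is None:
        components[b] = DisjointSet()
    components[a].union(components[b])

# Collect the root of every component that contains at least one black edge.
uniqueComponents = set()
for x in components:
    if x is not None:
        uniqueComponents.add(x.findParent())

# Start from all triplets and subtract the invalid ones.
valid_triplets = nc3(n)
for x in uniqueComponents:
    # All three vertices inside the same black component.
    valid_triplets -= nc3(x.size)
    # Exactly two vertices inside this component, the third outside it.
    valid_triplets -= nc2(x.size) * (n - x.size)

print(valid_triplets % (10**9 + 7))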
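
As a compact restatement of what the code above computes, the count can be written as a single formula. The symbols k and s_i are introduced here only for illustration: s_1, ..., s_k are the sizes of the components joined by black edges (a vertex with no black edge is a component of size 1 and contributes nothing).

```latex
\text{answer} \;=\; \binom{N}{3} \;-\; \sum_{i=1}^{k} \left[ \binom{s_i}{3} + \binom{s_i}{2}\,(N - s_i) \right] \pmod{10^9 + 7}
```

The first term in the brackets removes triplets that lie entirely inside one black component; the second removes triplets with exactly two vertices in that component.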

 

EXPLANATION STEPS

1. Understand the Problem Statement: Objective: Count the unordered triplets (a, b, c) of vertices for which each of the three paths a-b, b-c and c-a contains at least one red edge.

2. Key Observation: In a tree the path between two vertices is unique, so that path contains a red edge exactly when the two vertices are not connected by black edges alone. A triplet is therefore valid only if its three vertices lie in three different black-edge components.

3. Initialize Data Structures: Disjoint Set: Each node stores a parent pointer and, at the root, the size of its component. A components array maps each vertex to its disjoint-set node.

4. Build the Components: Read the N-1 edges. Skip every red edge and union the two endpoints of every black edge. Path compression and union by size keep each operation close to constant time.

5. Count by Complement: Start with all C(N, 3) triplets. For each component of size s, subtract C(s, 3) (all three vertices in the component) and C(s, 2) * (N - s) (exactly two vertices in the component and one outside).

6. Output the Result: Print the remaining count modulo 10^9 + 7.
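
The approach of Solution 1 can also be packaged as a function, which makes it easier to test on small trees. The following is a minimal sketch, not the original author's code: the names `count_red_triplets` and `brute_force`, the array-based disjoint set, and the five-vertex example tree are all assumptions made for illustration.

```python
from itertools import combinations

MOD = 10**9 + 7

def count_red_triplets(n, edges):
    """Count triplets whose three pairwise paths each contain a red edge.

    `edges` is a list of (u, v, color) with 1-based vertices and
    color 'r' or 'b' (hypothetical input shape for this sketch).
    """
    parent = list(range(n + 1))
    size = [1] * (n + 1)

    def find(x):
        # Iterative find with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v, color in edges:
        if color == 'r':
            continue  # red edges never merge components
        ru, rv = find(u), find(v)
        if ru == rv:
            continue
        if size[ru] < size[rv]:
            ru, rv = rv, ru
        parent[rv] = ru  # union by size
        size[ru] += size[rv]

    total = n * (n - 1) * (n - 2) // 6  # C(n, 3)
    for v in range(1, n + 1):
        if find(v) == v:
            s = size[v]
            total -= s * (s - 1) * (s - 2) // 6      # all three inside
            total -= s * (s - 1) // 2 * (n - s)      # exactly two inside
    return total % MOD

def brute_force(n, edges):
    """Slow cross-check: label black components, then test every triplet."""
    comp = list(range(n + 1))
    changed = True
    while changed:  # propagate the smallest label along black edges
        changed = False
        for u, v, color in edges:
            if color == 'b' and comp[u] != comp[v]:
                comp[u] = comp[v] = min(comp[u], comp[v])
                changed = True
    return sum(
        1
        for a, b, c in combinations(range(1, n + 1), 3)
        if comp[a] != comp[b] and comp[b] != comp[c] and comp[a] != comp[c]
    ) % MOD

# Hypothetical example tree: path 1-2-3-4-5 with colors b, r, r, b.
example = [(1, 2, 'b'), (2, 3, 'r'), (3, 4, 'r'), (4, 5, 'b')]
print(count_red_triplets(5, example))  # → 4
```

In this example the black components are {1, 2}, {3} and {4, 5}, so a valid triplet takes one vertex from each, giving 2 * 1 * 2 = 4. The brute-force helper is cubic in N and is only meant for checking small cases.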