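// Company Tree: find the non-leaf node whose subtree (itself plus all descendants) has the highest average value.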
import java.util.*;
/** A node in a company tree: an integer value plus an ordered list of child nodes. */
class Node {
    /** The value stored at this node. */
    int val;
    /** Direct children of this node; starts out empty and is populated by callers. */
    ArrayList<Node> children;

    /**
     * Creates a childless node holding the given value.
     *
     * @param val the value stored at this node
     */
    public Node(int val) {
        this.children = new ArrayList<>();
        this.val = val;
    }
}
/** Aggregate of a subtree: the total of its node values and how many nodes it contains. */
class Record {
    /** Sum of the values of every node in the subtree. */
    int sum;
    /** Number of nodes in the subtree. */
    int count;

    /**
     * Creates an aggregate from an already-computed total and node count.
     *
     * @param sum   total value of the subtree
     * @param count number of nodes in the subtree
     */
    public Record(int sum, int count) {
        this.count = count;
        this.sum = sum;
    }
}
/**
 * Finds the node of a company tree whose subtree (the node plus all of its
 * descendants) has the highest average value. Leaf nodes are never chosen;
 * their values only contribute to their ancestors' averages.
 *
 * <p>Search state is kept in static fields, so this class is not safe for
 * concurrent use.
 */
public class CompanyTree {
    /** Highest subtree average seen so far in the current search. */
    private static double resAve = Double.NEGATIVE_INFINITY;
    /** Node whose subtree produced {@code resAve}; null until a non-leaf node is visited. */
    private static Node result;

    /**
     * Returns the non-leaf node with the highest subtree average.
     *
     * @param root root of the tree; may be null
     * @return the best node, or null if {@code root} is null or has no children
     * @throws ArithmeticException if a subtree sum overflows {@code int}
     */
    public static Node getHighAve(Node root) {
        // Reset state so a repeated call never returns a stale node from an earlier tree.
        // NEGATIVE_INFINITY (not Double.MIN_VALUE, which is the smallest positive double)
        // lets trees whose averages are all <= 0 still produce a result.
        resAve = Double.NEGATIVE_INFINITY;
        result = null;
        if (root == null) return null;
        dfs(root);
        return result;
    }

    /**
     * Array-returning variant of {@link #dfs(Node)}; has the same side effects on the
     * static search state (it does not reset it).
     *
     * @param root non-null subtree root
     * @return {@code {sum, count}} for the subtree rooted at {@code root}
     */
    public static int[] dfs1(Node root) {
        Record totals = dfs(root);
        return new int[] {totals.sum, totals.count};
    }

    /**
     * Post-order traversal that returns the subtree's total value and node count,
     * updating the best non-leaf candidate along the way. Does not reset the static
     * search state; use {@link #getHighAve(Node)} for a fresh search.
     *
     * @param root non-null subtree root
     * @return sum and count of the subtree rooted at {@code root}
     * @throws ArithmeticException if a subtree sum overflows {@code int}
     */
    public static Record dfs(Node root) {
        if (root.children == null || root.children.isEmpty()) {
            // Leaves contribute to their ancestors but are not candidates themselves.
            return new Record(root.val, 1);
        }
        int sum = root.val;
        int count = 1;
        for (Node child : root.children) {
            Record sub = dfs(child);
            // Fail loudly instead of silently wrapping and reporting a wrong average.
            sum = Math.addExact(sum, sub.sum);
            count += sub.count;
        }
        double average = (double) sum / count;
        // Strict comparison: on ties, the node whose traversal finishes first wins.
        if (average > resAve) {
            resAve = average;
            result = root;
        }
        return new Record(sum, count);
    }

    public static void main(String[] args) {
        Node root = new Node(1);
        Node l21 = new Node(2);
        Node l22 = new Node(3);
        Node l23 = new Node(4);
        Node l31 = new Node(5);
        Node l32 = new Node(5);
        Node l33 = new Node(5);
        Node l34 = new Node(5);
        Node l35 = new Node(5);
        Node l36 = new Node(5);
        l21.children.add(l31);
        l21.children.add(l32);
        l22.children.add(l33);
        l22.children.add(l34);
        l23.children.add(l35);
        l23.children.add(l36);
        root.children.add(l21);
        root.children.add(l22);
        root.children.add(l23);
        Node res = getHighAve(root);
        if (res == null) {
            System.out.println("no non-leaf node");
        } else {
            System.out.println(res.val + " " + resAve);
        }
    }
}