Question

Formatted question description: https://leetcode.ca/all/222.html

 222	Count Complete Tree Nodes

 Given a complete binary tree, count the number of nodes.

 Note:
     Definition of a complete binary tree from Wikipedia:
         In a complete binary tree
            every level, except possibly the last, is completely filled,
            and all nodes in the last level are as far left as possible.
            It can have between 1 and 2^h nodes inclusive at the last level h.

 Example:

 Input:
    1
   / \
  2   3
 / \  /
4  5 6

 Output: 6

 @tag-tree

Algorithm

First, we need a getHeight function that follows left children from a node and counts the edges on that path, returning -1 for a null node. In a complete binary tree the leftmost path is the longest one. So for the current node, h = getHeight(node) is also the maximum height (number of levels) of its left subtree.

  • Call the getHeight function on the current node to get h. If it returns -1, the current node does not exist, so return 0 directly.
  • Otherwise, call the getHeight function on the right child.
    • If the return value is h-1, the right subtree reaches the last level. The left subtree is then a perfect binary tree with h levels, containing 2^h - 1 nodes. Adding the current node gives 2^h nodes in total, that is, 1<<h,
      • then add the return value of calling the recursive function on the right child node.
    • If the return value of calling the getHeight function on the right child is not h-1 (it is h-2), the last level ends inside the left subtree. The right subtree must therefore be a perfect tree with h-1 levels,
      • so the right subtree contains 2^(h-1) - 1 nodes. Adding the current node gives 2^(h-1), that is, 1<<(h-1). Then add the return value of calling the recursive function on the left child node.

Code

Java

  • 
    public class Count_Complete_Tree_Nodes {
        /**
         * Definition for a binary tree node.
         * public class TreeNode {
         *     int val;
         *     TreeNode left;
         *     TreeNode right;
         *     TreeNode(int x) { val = x; }
         * }
         */

        /**
         * Counts nodes by walking a single root-to-leaf path.
         *
         * <p>At each node it compares the left-spine height of the node with that of its right
         * child to decide which child subtree is perfect. It adds that subtree's size plus the
         * current node in closed form, then continues into the other child.
         */
        public class Solution_optimize {

            /**
             * Counts the nodes of a complete binary tree.
             *
             * @param root root of the tree; may be {@code null}
             * @return the number of nodes, {@code 0} for an empty tree
             */
            public int countNodes(TreeNode root) {
                // h = number of edges on the left spine below root (-1 when root is null);
                // equivalently, the number of levels in root's left subtree.
                int h = getHeight(root);

                if (h < 0) {
                    return 0;
                }

                if (getHeight(root.right) == h - 1) {
                    // The right subtree reaches the last level, so the left subtree is perfect
                    // with h levels: 2^h - 1 nodes, plus root gives 2^h.
                    return (1 << h) + countNodes(root.right);
                }
                // Here getHeight(root.right) == h - 2: the right subtree is perfect with h - 1
                // levels: 2^(h-1) - 1 nodes, plus root gives 2^(h-1).
                return (1 << (h - 1)) + countNodes(root.left);
            }

            /**
             * Returns the number of edges on the path that follows left children from
             * {@code node}, or {@code -1} when {@code node} is {@code null}.
             */
            int getHeight(TreeNode node) {
                return node != null ? (1 + getHeight(node.left)) : -1;
            }
        }

        /**
         * Counts nodes by comparing the leftmost and rightmost spine heights.
         *
         * <p>If the two spine heights are equal, the tree is perfect and its size is computed in
         * closed form. Otherwise it recurses into both children.
         */
        public class Solution {

            /**
             * Counts the nodes of a complete binary tree.
             *
             * @param root root of the tree; may be {@code null}
             * @return the number of nodes, {@code 0} for an empty tree
             */
            public int countNodes(TreeNode root) {
                if (root == null) {
                    return 0;
                }

                int leftHeight = findLeftHeight(root);
                int rightHeight = findRightHeight(root);

                if (leftHeight == rightHeight) {
                    // Perfect tree with leftHeight (>= 1) levels: 2^leftHeight - 1 nodes.
                    return (1 << leftHeight) - 1;
                }

                return 1 + countNodes(root.left) + countNodes(root.right);
            }

            /** Returns the number of nodes on the path following left children; 0 for null. */
            private int findLeftHeight(TreeNode root) {
                int height = 0;
                for (TreeNode node = root; node != null; node = node.left) {
                    height++;
                }
                return height;
            }

            /** Returns the number of nodes on the path following right children; 0 for null. */
            private int findRightHeight(TreeNode root) {
                int height = 0;
                for (TreeNode node = root; node != null; node = node.right) {
                    height++;
                }
                return height;
            }
        }
    }
    
  • // OJ: https://leetcode.com/problems/count-complete-tree-nodes/
    // Time: O(H^2)
    // Space: O(H)
    class Solution {
        int countLeft(TreeNode *root) {
            int cnt = 0;
            for (; root; ++cnt, root = root->left);
            return cnt;
        }
        int countRight(TreeNode *root) {
            int cnt = 0;
            for (; root; ++cnt, root = root->right);
            return cnt;
        }
    public:
        int countNodes(TreeNode* root) {
            if (!root) return 0;
            int left = countLeft(root), right = countRight(root);
            if (left == right) return (1 << left) - 1;
            return countNodes(root->left) + countNodes(root->right) + 1;
        }
    };
    
  • # Definition for a binary tree node.
    # class TreeNode(object):
    #     def __init__(self, x):
    #         self.val = x
    #         self.left = None
    #         self.right = None
    
    class Solution(object):
      def getHeight(self, root):
        """Return how many nodes lie on the path that follows left children from root.

        Returns 0 when root is falsy (e.g. None).
        """
        node, depth = root, 0
        while node:
          node, depth = node.left, depth + 1
        return depth

      def countNodes(self, root):
        """Count the nodes of a complete binary tree by walking one root-to-leaf path.

        At each node the left-spine heights of its two children are compared.
        If they differ, the right subtree is perfect with right_h levels, so it and
        the current node contribute 2**right_h and the walk moves left. Otherwise
        the left subtree is perfect with left_h levels, contributing 2**left_h
        together with the current node, and the walk moves right.
        """
        total = 0
        node = root
        while node:
          left_h = self.getHeight(node.left)
          right_h = self.getHeight(node.right)
          if left_h != right_h:
            total += 1 << right_h
            node = node.left
          else:
            total += 1 << left_h
            node = node.right
        return total
    
    

All Problems

All Solutions