Question

Formatted question description: https://leetcode.ca/all/230.html

230. Kth Smallest Element in a BST

Level

Medium

Description

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:

You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up:

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Solution

The inorder traversal of a binary search tree contains all values in ascending order. Therefore, obtain the inorder traversal of the binary search tree and return the k-th smallest element.

Use a counter: each time a node is visited, increment the counter by 1; when the counter reaches k, return the current node's value.
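As a minimal sketch of the counter idea, the snippet below walks the tree in order with a generator and stops at the k-th value. The function names `inorder` and `kth_smallest` and the error raised for an out-of-range k are illustrative assumptions, not part of the original solutions; the tree is Example 2 from the description.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield the values of a BST in ascending order (left, node, right)."""
    if node:
        yield from inorder(node.left)
        yield node.val
        yield from inorder(node.right)


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # The counter starts at 1 and grows by 1 per visited node;
    # the traversal stops as soon as the counter reaches k.
    for count, val in enumerate(inorder(root), start=1):
        if count == k:
            return val
    raise ValueError("k is larger than the number of nodes")


# Example 2 from the description: [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(list(inorder(root)))    # [1, 2, 3, 4, 5, 6]
print(kth_smallest(root, 3))  # 3
```

Because the generator is lazy, only the nodes up to the k-th one (plus the path leading to them) are visited.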

Code

  • import java.util.Comparator;
    import java.util.PriorityQueue;
    
    public class Kth_Smallest_Element_in_a_BST {
        /**
         * Definition for a binary tree node.
         * public class TreeNode {
         *     int val;
         *     TreeNode left;
         *     TreeNode right;
         *     TreeNode(int x) { val = x; }
         * }
         */
        class Solution {
    
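            // Max-heap of the k smallest values seen so far; its top is the k-th smallest.
            // This variant visits every node and does not rely on the inorder order.
            // Comparator.reverseOrder() avoids the integer overflow that o2 - o1 can cause.
            PriorityQueue<Integer> heap = new PriorityQueue<>(Comparator.reverseOrder());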
    
            public int kthSmallest(TreeNode root, int k) {
    
                dfs(root, k);
    
                return heap.peek();
            }
    
            private void dfs(TreeNode root, int k) {
    
                if (root == null) {
                    return;
                }
    
                // Keep at most k values in the heap.
                if (heap.size() < k) {
                    heap.offer(root.val);
    
                    // Follow-up idea: heap.remove(Object) removes by value, not by index,
                    // so on a delete, remove the value from both the tree and the heap.
    
                } else {
                    int val = root.val;
                    if (val < heap.peek()) {
                        heap.poll();
                        heap.offer(val);
                    }
                }
                dfs(root.left, k);
                dfs(root.right, k);
            }
        }
    
    }
    
  • // OJ: https://leetcode.com/problems/kth-smallest-element-in-a-bst/
    // Time: O(N)
    // Space: O(H)
    class Solution {
    public:
        int kthSmallest(TreeNode* root, int k) {
            function<int(TreeNode*)> inorder = [&](TreeNode *root) {
                if (!root) return -1;
                int val = inorder(root->left);
                if (val != -1) return val;
                if (--k == 0) return root->val;
                return inorder(root->right);
            };
            return inorder(root);
        }
    };
    
  • # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right
    class Solution:
        def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
            stk = []
            while root or stk:
                if root:
                    stk.append(root)
                    root = root.left
                else:
                    root = stk.pop()
                    k -= 1
                    if k == 0:
                        return root.val
                    root = root.right
    
    ############
    
    # Definition for a binary tree node.
    # class TreeNode(object):
    #     def __init__(self, x):
    #         self.val = x
    #         self.left = None
    #         self.right = None
    
    class Solution(object):
      def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        stack = [(1, root)]
        while stack:
          cmd, p = stack.pop()
          if not p:
            continue
          if cmd == 0:
            k -= 1
            if k == 0:
              return p.val
          else:
            stack.append((1, p.right))
            stack.append((0, p))
            stack.append((1, p.left))
    
    

Solution - Follow up

Counting the nodes of the left subtree by traversing it on every query is inefficient. Instead, modify the tree node structure so that each node stores the number of nodes in its subtree, that is, the node itself plus all nodes in its left and right subtrees.

In this way, the size of any left subtree is available immediately, so the target value can be located quickly.

Define the new node structure and build a new tree from the original one recursively. When creating a node, add the count values of its left and right children to the node's own count.

Then, in the function that finds the k-th smallest element, first build the new tree and then call the recursive search function. For repeated queries to benefit, the counted tree should be built once and its counts kept up to date on every insert and delete, rather than rebuilt on each call.

In the recursive function, the count of the left child cannot be read directly, because the left child may not exist, so check for it first:

  • If the left child exists, let cnt be its count: if k ≤ cnt, recurse into the left child with the same k; if k = cnt + 1, return the current node's value; otherwise recurse into the right child with k reduced by cnt + 1
  • If it does not exist
    • when k is 1, return the current node's value directly
    • otherwise, call the recursive function on the right child with k decremented by 1
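The follow-up asks about a tree that changes often, so the counts have to stay correct across modifications. The sketch below is one possible way to do that for insertions only: it updates `count` along the insertion path and then answers queries with the same case analysis as above. The names `CountedNode`, `size`, and `insert` are hypothetical, values are assumed to be distinct, and deletion and rebalancing are left out.

```python
from typing import Optional


class CountedNode:
    """BST node that also stores the size of its subtree (hypothetical name)."""

    def __init__(self, val: int):
        self.val = val
        self.count = 1  # this node plus all nodes below it
        self.left: Optional["CountedNode"] = None
        self.right: Optional["CountedNode"] = None


def size(node: Optional[CountedNode]) -> int:
    # A missing child contributes zero nodes.
    return node.count if node else 0


def insert(node: Optional[CountedNode], val: int) -> CountedNode:
    """Insert val (assumed not already present) and refresh counts on the path."""
    if node is None:
        return CountedNode(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.count = 1 + size(node.left) + size(node.right)
    return node


def kth_smallest(node: Optional[CountedNode], k: int) -> int:
    """Return the k-th smallest value (1-based) by descending one path."""
    while node:
        left = size(node.left)
        if k <= left:
            node = node.left          # answer lies in the left subtree
        elif k == left + 1:
            return node.val           # current node is the k-th smallest
        else:
            k -= left + 1             # skip the left subtree and this node
            node = node.right
    raise ValueError("k out of range")


root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)
print(kth_smallest(root, 3))  # 3

root = insert(root, 0)        # the tree changes; counts are updated in place
print(kth_smallest(root, 3))  # 2
```

With counts maintained this way, each query follows a single root-to-node path, so its cost is proportional to the tree height instead of the number of nodes. Using `size()` for a possibly missing child also removes the need for a separate "left child does not exist" branch.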
  • class Solution_followUp {
        public int kthSmallest(TreeNode root, int k) {
            MyTreeNode node = build(root);
            return dfs(node, k);
        }
    
        class MyTreeNode {
            int val;
            int count; // key point to add up and find k-th element
            MyTreeNode left;
            MyTreeNode right;
            MyTreeNode(int x) {
                this.val = x;
                this.count = 1;
            }
        };
    
        MyTreeNode build(TreeNode root) {
            if (root == null) return null;
            MyTreeNode node = new MyTreeNode(root.val);
            node.left = build(root.left);
            node.right = build(root.right);
            if (node.left != null) node.count += node.left.count;
            if (node.right != null) node.count += node.right.count;
            return node;
        }
    
        int dfs(MyTreeNode node, int k) {
            if (node.left != null) {
                int cnt = node.left.count;
                if (k <= cnt) {
                    return dfs(node.left, k);
                } else if (k > cnt + 1) {
                    return dfs(node.right, k - 1 - cnt); // -1 is to exclude current root
                } else { // k == cnt + 1
                    return node.val;
                }
            } else {
                if (k == 1) return node.val;
                return dfs(node.right, k - 1);
            }
        }
    }
    
  • // Follow up
    class Solution_followUp {
    public:
        struct MyTreeNode {
            int val;
            int count;
            MyTreeNode *left;
            MyTreeNode *right;
            MyTreeNode(int x) : val(x), count(1), left(NULL), right(NULL) {}
        };
        
        MyTreeNode* build(TreeNode* root) {
            if (!root) return NULL;
            MyTreeNode *node = new MyTreeNode(root->val);
            node->left = build(root->left);
            node->right = build(root->right);
            if (node->left) node->count += node->left->count;
            if (node->right) node->count += node->right->count;
            return node;
        }
        
        int kthSmallest(TreeNode* root, int k) {
            MyTreeNode *node = build(root);
            return helper(node, k);
        }
        
        int helper(MyTreeNode* node, int k) {
            if (node->left) {
                int cnt = node->left->count;
                if (k <= cnt) {
                    return helper(node->left, k);
                } else if (k > cnt + 1) {
                    return helper(node->right, k - 1 - cnt);
                }
                return node->val;
            } else {
                if (k == 1) return node->val;
                return helper(node->right, k - 1);
            }
        }
    };
    
  • # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right
    class Solution_followUp:
        def kthSmallest(self, root: TreeNode, k: int) -> int:
            node = self.build(root)
            return self.dfs(node, k)
    
        class MyTreeNode:
            def __init__(self, x):
                self.val = x
                self.count = 1
                self.left = None
                self.right = None
    
        def build(self, root: TreeNode) -> MyTreeNode:
            if not root:
                return None
            node = self.MyTreeNode(root.val)
            node.left = self.build(root.left)
            node.right = self.build(root.right)
            if node.left:
                node.count += node.left.count
            if node.right:
                node.count += node.right.count
            return node
    
        def dfs(self, node: MyTreeNode, k: int) -> int:
            if node.left:
                cnt = node.left.count
                if k <= cnt:
                    return self.dfs(node.left, k)
                elif k > cnt + 1:
                    return self.dfs(node.right, k - 1 - cnt)
                else:
                    return node.val
            else:
                if k == 1:
                    return node.val
                return self.dfs(node.right, k - 1)
    
    
