Question

Formatted question description: https://leetcode.ca/all/230.html

230. Kth Smallest Element in a BST

Level

Medium

Description

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:

You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up:

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Solution

The inorder traversal of a binary search tree visits all values in ascending order. Therefore, traverse the tree in inorder and return the kth value visited.

Use a counter that increments by 1 each time a node is visited; when the counter reaches k, return the current node's value.
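Below is a minimal iterative sketch of this counter-based inorder traversal. It assumes a self-contained `TreeNode` class (a stand-in for LeetCode's definition) and a hypothetical wrapper class `KthSmallestInorder`. An explicit stack replaces recursion, and the loop stops as soon as the k-th node is visited, so only part of the tree is traversed when k is small.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class KthSmallestInorder {
    // Minimal binary tree node, standing in for LeetCode's TreeNode.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Iterative inorder traversal: a BST yields its values in ascending order,
    // so the k-th node visited holds the k-th smallest value.
    static int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode node = root;
        int count = 0;
        while (node != null || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the path on the stack.
            while (node != null) {
                stack.push(node);
                node = node.left;
            }
            node = stack.pop();
            count++;                 // visit the node
            if (count == k) {
                return node.val;     // stop early; the remaining nodes are never visited
            }
            node = node.right;       // continue with the right subtree
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        // Example 2: [5,3,6,2,4,null,null,1]
        TreeNode root = new TreeNode(5,
                new TreeNode(3,
                        new TreeNode(2, new TreeNode(1), null),
                        new TreeNode(4)),
                new TreeNode(6));
        System.out.println(kthSmallest(root, 3)); // prints 3
    }
}
```

This runs in O(h + k) time, where h is the tree height, and uses O(h) extra space for the stack.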

Code

Java

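import java.util.Collections;
import java.util.PriorityQueue;

public class Kth_Smallest_Element_in_a_BST {
    /**
     * Definition for a binary tree node.
     * public class TreeNode {
     *     int val;
     *     TreeNode left;
     *     TreeNode right;
     *     TreeNode(int x) { val = x; }
     * }
     */
    // Note: this version keeps a max-heap of the k smallest values instead of an
    // inorder counter. It does not rely on BST ordering, so it works on any binary
    // tree, in O(n log k) time.
    class Solution {

        // Max-heap of the k smallest values seen so far; its top is the k-th smallest.
        // Collections.reverseOrder() avoids the integer overflow of an "o2 - o1" comparator.
        PriorityQueue<Integer> heap = new PriorityQueue<>(Collections.reverseOrder());

        public int kthSmallest(TreeNode root, int k) {
            heap.clear(); // reset state in case this Solution instance is reused

            dfs(root, k);

            return heap.peek();
        }

        private void dfs(TreeNode root, int k) {

            if (root == null) {
                return;
            }

            // maintain a heap of at most k values
            if (heap.size() < k) {
                heap.offer(root.val);

                // Follow-up note: PriorityQueue.remove(Object) removes by value, not by index,
                // so a deleted value can be removed from both the tree and the heap; the heap
                // may then hold fewer than k values and need to be refilled.

            } else {
                int val = root.val;
                // a smaller value displaces the current k-th smallest
                if (val < heap.peek()) {
                    heap.poll();
                    heap.offer(val);
                }
            }
            dfs(root.left, k);
            dfs(root.right, k);
        }
    }

}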

Solution - Follow up

Recounting all nodes of the left subtree at every recursive call is inefficient, so modify the tree node structure to also store a count: the number of nodes in the subtree rooted at that node, including the node itself and all nodes in its left and right subtrees.

With this count stored, the size of any left subtree is available in O(1), so the target value can be located quickly.

Define the new node structure, then recursively build a new tree that mirrors the original. When creating each node, add the count values of its left and right children to its own count, which starts at 1.

Then, in the function that finds the k-th smallest element, first build the new tree and then call the recursive search function.

In the recursive function, the count of the left child cannot be read directly, because the left child may not exist, so check for it first:

  • If the left child exists, let cnt be its count:
    • if k ≤ cnt, the answer is in the left subtree, so recurse on the left child with the same k
    • if k > cnt + 1, recurse on the right child with k - cnt - 1 (skipping the left subtree and the current node)
    • otherwise k equals cnt + 1, so return the current node's value
  • If it does not exist:
    • when k is 1, return the current node's value directly
    • otherwise, recurse on the right child with k decremented by 1

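As a worked check of these cases, here are the searches for the two examples, where count(x) denotes the count stored at the node with value x:

```latex
\begin{aligned}
&\text{Example 2, } k = 3:\\
&\quad \text{node } 5:\ cnt = count(3) = 4,\ k = 3 \le 4 \;\Rightarrow\; \text{recurse left with } k = 3\\
&\quad \text{node } 3:\ cnt = count(2) = 2,\ k = 3 = cnt + 1 \;\Rightarrow\; \text{return } 3\\
&\text{Example 1, } k = 1:\\
&\quad \text{node } 3:\ cnt = count(1) = 2,\ k = 1 \le 2 \;\Rightarrow\; \text{recurse left with } k = 1\\
&\quad \text{node } 1:\ \text{no left child},\ k = 1 \;\Rightarrow\; \text{return } 1
\end{aligned}
```

Both results match the expected outputs of the examples.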
class Solution_followUp {
    public int kthSmallest(TreeNode root, int k) {
        MyTreeNode node = build(root);
        return dfs(node, k);
    }

    class MyTreeNode {
        int val;
        int count; // key point to add up and find k-th element
        MyTreeNode left;
        MyTreeNode right;
        MyTreeNode(int x) {
            this.val = x;
            this.count = 1;
        }
    }

    MyTreeNode build(TreeNode root) {
        if (root == null) return null;
        MyTreeNode node = new MyTreeNode(root.val);
        node.left = build(root.left);
        node.right = build(root.right);
        if (node.left != null) node.count += node.left.count;
        if (node.right != null) node.count += node.right.count;
        return node;
    }

    int dfs(MyTreeNode node, int k) {
        if (node.left != null) {
            int cnt = node.left.count;
            if (k <= cnt) {
                return dfs(node.left, k);
            } else if (k > cnt + 1) {
                return dfs(node.right, k - 1 - cnt); // -1 is to exclude current root
            } else {
                return node.val;
            }
        } else {
            if (k == 1) return node.val;
            return dfs(node.right, k - 1);
        }
    }
}
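The solution above rebuilds the counted tree on every call, which costs O(n) each time. When the BST is modified often, as the follow-up describes, the count can instead live in the BST's own nodes and be updated along the search path on each insert. This is a minimal sketch under stated assumptions. `CountedBst` and its `Node` class are hypothetical names. The tree is a plain, unbalanced BST. Duplicates are sent to the right. Only insertion is shown; a delete would decrement the counts along its path in the same way.

```java
public class CountedBst {
    // Hypothetical node type: each node caches the size of its own subtree.
    static class Node {
        int val;
        int count = 1;   // nodes in this subtree, including this node
        Node left, right;
        Node(int val) { this.val = val; }
    }

    private Node root;

    private static int size(Node n) {
        return n == null ? 0 : n.count;
    }

    // Standard BST insert that also refreshes the count of every node on the path.
    // Assumption: duplicates go to the right subtree.
    public void insert(int val) {
        root = insert(root, val);
    }

    private static Node insert(Node n, int val) {
        if (n == null) return new Node(val);
        if (val < n.val) n.left = insert(n.left, val);
        else n.right = insert(n.right, val);
        n.count = 1 + size(n.left) + size(n.right);
        return n;
    }

    // k-th smallest (1-based) using the cached counts; walks a single path, no rebuild.
    public int kthSmallest(int k) {
        Node n = root;
        while (n != null) {
            int cnt = size(n.left);
            if (k <= cnt) {
                n = n.left;
            } else if (k > cnt + 1) {
                k -= cnt + 1;   // skip the left subtree and the current node
                n = n.right;
            } else {
                return n.val;
            }
        }
        throw new IllegalArgumentException("k out of range");
    }

    public static void main(String[] args) {
        CountedBst bst = new CountedBst();
        for (int v : new int[] {5, 3, 6, 2, 4, 1}) bst.insert(v); // builds Example 2's tree
        System.out.println(bst.kthSmallest(3)); // prints 3
        bst.insert(0);
        System.out.println(bst.kthSmallest(3)); // prints 2
    }
}
```

Both insert and query cost O(h), where h is the tree height. Without rebalancing, h can degrade to O(n). A self-balancing tree (for example a red-black or AVL tree) that also maintains the counts during rotations would keep both operations at O(log n).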

cpp

class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        int cnt = count(root->left);
        if (k <= cnt) {
            return kthSmallest(root->left, k);
        } else if (k > cnt + 1) {
            return kthSmallest(root->right, k - cnt - 1);
        }
        return root->val;
    }
    int count(TreeNode* node) {
        if (!node) return 0;
        return 1 + count(node->left) + count(node->right);
    }
};

// Follow up
class Solution_followUp {
public:
    struct MyTreeNode {
        int val;
        int count;
        MyTreeNode *left;
        MyTreeNode *right;
        MyTreeNode(int x) : val(x), count(1), left(NULL), right(NULL) {}
    };
    
    MyTreeNode* build(TreeNode* root) {
        if (!root) return NULL;
        MyTreeNode *node = new MyTreeNode(root->val);
        node->left = build(root->left);
        node->right = build(root->right);
        if (node->left) node->count += node->left->count;
        if (node->right) node->count += node->right->count;
        return node;
    }
    
    int kthSmallest(TreeNode* root, int k) {
        MyTreeNode *node = build(root);
        return helper(node, k);
    }
    
    int helper(MyTreeNode* node, int k) {
        if (node->left) {
            int cnt = node->left->count;
            if (k <= cnt) {
                return helper(node->left, k);
            } else if (k > cnt + 1) {
                return helper(node->right, k - 1 - cnt);
            }
            return node->val;
        } else {
            if (k == 1) return node->val;
            return helper(node->right, k - 1);
        }
    }
};
