Welcome to Subscribe On Youtube

Question

Formatted question description: https://leetcode.ca/all/230.html

230. Kth Smallest Element in a BST

Level

Medium

Description

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:

You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up:

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Solution

The inorder traversal of a binary search tree contains all values in ascending order. Therefore, obtain the inorder traversal of the binary search tree and return the k-th smallest element.

Keep a counter during the in-order traversal: increment it each time a node is visited, and when it reaches k, return the current node's value.

Solution - Follow up

Counting all the nodes of the left subtree on every recursive call is inefficient. Instead, extend the tree-node structure so that each node also stores the number of nodes in its subtree: the node itself plus all nodes in its left and right subtrees.

This way, the size of any left subtree can be read in O(1), so the search can move directly toward the target value.

Define the new node structure and build a new tree from the original one, for example recursively. When creating each node, set its count to 1 plus the counts of its left and right children.

Then, in the function that finds the k-th smallest element, first build the new tree and then call the recursive search function on it.

In the recursive function, you cannot read the left child's count directly, because the left child may not exist. So first check whether it exists:

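  • If the left child exists, let cnt be its count: if k ≤ cnt, recurse into the left subtree; if k = cnt + 1, return the current node's value; otherwise, recurse into the right subtree with k − cnt − 1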
  • If it does not exist
    • when k is 1, the current node value is directly returned
    • otherwise, the recursive function is called on the right child node, and k is decremented by 1

Code

  • import java.util.Comparator;
    import java.util.PriorityQueue;
    
    public class Kth_Smallest_Element_in_a_BST {
        /**
         * Definition for a binary tree node.
         * public class TreeNode {
         *     int val;
         *     TreeNode left;
         *     TreeNode right;
         *     TreeNode(int x) { val = x; }
         * }
         */
        class Solution {

            /**
             * Max-heap of the k smallest values seen so far; its head is the largest of them,
             * i.e. the current candidate for the k-th smallest. Integer.compare is used instead
             * of {@code o2 - o1}, which overflows for operands far apart (e.g. MIN_VALUE vs. 1).
             */
            PriorityQueue<Integer> heap = new PriorityQueue<>(new Comparator<Integer>() {
                @Override
                public int compare(Integer o1, Integer o2) {
                    return Integer.compare(o2, o1);
                }
            });

            /**
             * Returns the k-th smallest value (1-indexed) in the tree rooted at {@code root}.
             * Every node is visited and only a bounded heap is kept, so this runs in O(n log k)
             * time and O(k + h) space; it does not make use of the BST ordering.
             */
            public int kthSmallest(TreeNode root, int k) {
                // The heap is an instance field: clear it so repeated calls on the same
                // Solution do not see values left over from a previous tree.
                heap.clear();
                dfs(root, k);
                return heap.peek();
            }

            /** Pre-order walk that keeps only the k smallest values in {@code heap}. */
            private void dfs(TreeNode root, int k) {
                if (root == null) {
                    return;
                }
                if (heap.size() < k) {
                    heap.offer(root.val);
                } else if (root.val < heap.peek()) {
                    // Evict the largest retained value in favour of a smaller one.
                    heap.poll();
                    heap.offer(root.val);
                }
                // Follow-up note: deleting a value from this bounded heap alone is not enough,
                // because the heap would then hold k-1 values and the new k-th smallest
                // could not be known without traversing the tree again.
                dfs(root.left, k);
                dfs(root.right, k);
            }
        }

    }
    
    /**
     * Follow-up approach: copy the tree into nodes that also store their subtree size,
     * then locate the k-th smallest value by walking a single root-to-node path.
     */
    class Solution_followUp {
        /** Returns the k-th smallest value (1-indexed) in the BST rooted at {@code root}. */
        public int kthSmallest(TreeNode root, int k) {
            return dfs(build(root), k);
        }

        /** Tree node augmented with the number of nodes in its subtree (itself included). */
        class MyTreeNode {
            int val;
            int count; // subtree size: the key to locating the k-th element quickly
            MyTreeNode left;
            MyTreeNode right;

            MyTreeNode(int x) {
                val = x;
                count = 1;
            }
        }

        /** Recursively copies {@code root}, filling in each copy's subtree size. */
        MyTreeNode build(TreeNode root) {
            if (root == null) {
                return null;
            }
            MyTreeNode copy = new MyTreeNode(root.val);
            copy.left = build(root.left);
            copy.right = build(root.right);
            copy.count = 1 + sizeOf(copy.left) + sizeOf(copy.right);
            return copy;
        }

        /** Subtree size of {@code node}, treating a missing child as size 0. */
        private int sizeOf(MyTreeNode node) {
            return node == null ? 0 : node.count;
        }

        /**
         * Walks down from {@code node}. The left subtree holds the {@code leftSize} smallest
         * values, the current node has rank {@code leftSize + 1}, and all larger ranks lie in
         * the right subtree.
         */
        int dfs(MyTreeNode node, int k) {
            MyTreeNode cur = node;
            int rank = k;
            while (true) {
                int leftSize = sizeOf(cur.left);
                if (rank <= leftSize) {
                    cur = cur.left;
                } else if (rank == leftSize + 1) {
                    return cur.val;
                } else {
                    rank -= leftSize + 1; // skip the left subtree and the current node
                    cur = cur.right;
                }
            }
        }
    }
    
    ############
    
    /**
     * Definition for a binary tree node.
     * public class TreeNode {
     *     int val;
     *     TreeNode left;
     *     TreeNode right;
     *     TreeNode() {}
     *     TreeNode(int val) { this.val = val; }
     *     TreeNode(int val, TreeNode left, TreeNode right) {
     *         this.val = val;
     *         this.left = left;
     *         this.right = right;
     *     }
     * }
     */
    class Solution {
        /**
         * Iterative in-order traversal with an explicit stack; returns the value of the
         * k-th visited node, or 0 if the tree has fewer than k nodes.
         */
        public int kthSmallest(TreeNode root, int k) {
            Deque<TreeNode> pending = new ArrayDeque<>();
            TreeNode cur = root;
            int remaining = k;
            while (cur != null || !pending.isEmpty()) {
                // Push the whole left spine before visiting anything.
                while (cur != null) {
                    pending.push(cur);
                    cur = cur.left;
                }
                TreeNode visited = pending.pop();
                remaining--;
                if (remaining == 0) {
                    return visited.val;
                }
                cur = visited.right;
            }
            return 0;
        }
    }
    
  • // OJ: https://leetcode.com/problems/kth-smallest-element-in-a-bst/
    // Time: O(N)
    // Space: O(H)
    class Solution {
    public:
        int kthSmallest(TreeNode* root, int k) {
            function<int(TreeNode*)> inorder = [&](TreeNode *root) {
                if (!root) return -1;
                int val = inorder(root->left);
                if (val != -1) return val;
                if (--k == 0) return root->val;
                return inorder(root->right);
            };
            return inorder(root);
        }
    };
    
  • # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right
    class Solution:
        def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
            """Return the value of the k-th node visited by an in-order walk.

            The walk is iterative and uses an explicit stack. If the tree has fewer
            than k nodes, the loop finishes and None is returned implicitly.
            """
            pending = []
            current = root
            while current or pending:
                # Descend along left children, remembering every ancestor.
                while current:
                    pending.append(current)
                    current = current.left
                current = pending.pop()
                k -= 1
                if not k:
                    return current.val
                current = current.right
    
    ############
    
    # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right
    class Solution_followUp:
        """Follow-up approach: augment every node with the size of its subtree.

        With subtree sizes available, the k-th smallest value is found by walking a
        single root-to-node path instead of traversing the tree in order. Both the
        build and the search are iterative, so skewed (linked-list-shaped) trees do
        not exceed Python's recursion limit.
        """

        def kthSmallest(self, root: TreeNode, k: int) -> int:
            """Return the k-th smallest value (1-indexed) in the BST rooted at ``root``.

            Raises:
                ValueError: if k is not in 1..(number of nodes in the tree).
            """
            node = self.build(root)
            return self.dfs(node, k)

        class MyTreeNode:
            def __init__(self, x):
                self.val = x
                self.count = 1  # size of the subtree rooted here; starts with just this node
                self.left = None
                self.right = None

        def build(self, root: TreeNode) -> MyTreeNode:
            """Copy the tree into MyTreeNode objects whose ``count`` is their subtree size.

            Returns None for an empty tree.
            """
            if not root:
                return None
            new_root = self.MyTreeNode(root.val)
            visited = []  # copies in pre-order: every parent precedes its children
            stack = [(root, new_root)]
            while stack:
                src, dst = stack.pop()
                visited.append(dst)
                if src.left:
                    dst.left = self.MyTreeNode(src.left.val)
                    stack.append((src.left, dst.left))
                if src.right:
                    dst.right = self.MyTreeNode(src.right.val)
                    stack.append((src.right, dst.right))
            # Reverse pre-order handles children before parents, so each child's count
            # is already final when it is added into its parent.
            for dst in reversed(visited):
                if dst.left:
                    dst.count += dst.left.count
                if dst.right:
                    dst.count += dst.right.count
            return new_root

        def dfs(self, node: MyTreeNode, k: int) -> int:
            """Return the k-th smallest value in the size-augmented subtree ``node``.

            Walks down one path. At each node, k is compared with the size of the left
            subtree to decide whether the answer lies on the left, at this node, or on
            the right.

            Raises:
                ValueError: if k is not in 1..(size of the subtree).
            """
            while node is not None:
                left_count = node.left.count if node.left else 0
                if k <= left_count:
                    node = node.left
                elif k == left_count + 1:
                    return node.val
                else:
                    k -= left_count + 1  # skip the left subtree and the current node
                    node = node.right
            raise ValueError("k is out of range for this tree")
    
    
  • /**
     * Definition for a binary tree node.
     * type TreeNode struct {
     *     Val int
     *     Left *TreeNode
     *     Right *TreeNode
     * }
     */
    // kthSmallest returns the value of the k-th node visited by an iterative in-order
    // traversal of the tree rooted at root, or 0 if the tree has fewer than k nodes.
    func kthSmallest(root *TreeNode, k int) int {
    	var pending []*TreeNode
    	node := root
    	for node != nil || len(pending) > 0 {
    		// Push the whole left spine of the current subtree.
    		for node != nil {
    			pending = append(pending, node)
    			node = node.Left
    		}
    		last := len(pending) - 1
    		node = pending[last]
    		pending = pending[:last]
    		k--
    		if k == 0 {
    			return node.Val
    		}
    		node = node.Right
    	}
    	return 0
    }
    
  • /**
     * Definition for a binary tree node.
     * class TreeNode {
     *     val: number
     *     left: TreeNode | null
     *     right: TreeNode | null
     *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
     *         this.val = (val===undefined ? 0 : val)
     *         this.left = (left===undefined ? null : left)
     *         this.right = (right===undefined ? null : right)
     *     }
     * }
     */
    
    /**
     * Returns the k-th smallest value (1-indexed) in a BST using a recursive in-order walk.
     *
     * The walk stops as soon as the k-th node is visited. Completion is tracked with a
     * boolean rather than a -1 sentinel, so a node whose value is -1 cannot be mistaken for
     * "not found yet". Returns -1 if the tree has fewer than k nodes.
     */
    function kthSmallest(root: TreeNode | null, k: number): number {
        let remaining = k;
        let answer = -1;
        // Returns true once the k-th node has been visited and `answer` has been set.
        const visit = (node: TreeNode | null): boolean => {
            if (node == null) {
                return false;
            }
            if (visit(node.left)) {
                return true;
            }
            remaining--;
            if (remaining === 0) {
                answer = node.val;
                return true;
            }
            return visit(node.right);
        };
        visit(root);
        return answer;
    }
    
    
  • // Definition for a binary tree node.
    // #[derive(Debug, PartialEq, Eq)]
    // pub struct TreeNode {
    //   pub val: i32,
    //   pub left: Option<Rc<RefCell<TreeNode>>>,
    //   pub right: Option<Rc<RefCell<TreeNode>>>,
    // }
    //
    // impl TreeNode {
    //   #[inline]
    //   pub fn new(val: i32) -> Self {
    //     TreeNode {
    //       val,
    //       left: None,
    //       right: None
    //     }
    //   }
    // }
    use std::rc::Rc;
    use std::cell::RefCell;
    impl Solution {
        /// Returns the `k`-th smallest value (1-indexed) in the binary search tree at `root`.
        ///
        /// Performs an iterative in-order traversal with an explicit stack and stops as soon
        /// as the `k`-th node is visited. This takes O(h + k) time and O(h) extra space, where
        /// `h` is the tree height. The tree is only read: nodes are borrowed immutably and child
        /// links are cloned (`Rc` reference-count bumps), never detached. Callers holding other
        /// handles to the same tree therefore see it unchanged.
        ///
        /// # Panics
        ///
        /// Panics if `k` is not in `1..=n`, where `n` is the number of nodes in the tree.
        pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
            let mut remaining = k;
            let mut ancestors: Vec<Rc<RefCell<TreeNode>>> = Vec::new();
            let mut cur = root;
            loop {
                // Walk down the left spine, stacking every node whose value is still pending.
                while let Some(node) = cur.take() {
                    cur = node.borrow().left.clone();
                    ancestors.push(node);
                }
                let next = ancestors
                    .pop()
                    .expect("k must be between 1 and the number of nodes in the tree");
                remaining -= 1;
                if remaining == 0 {
                    return next.borrow().val;
                }
                cur = next.borrow().right.clone();
            }
        }
    }
    
    

All Problems

All Solutions