Welcome! Subscribe on YouTube.
Question
Formatted question description: https://leetcode.ca/all/230.html
230. Kth Smallest Element in a BST
Level
Medium
Description
Given a binary search tree, write a function `kthSmallest` to find the k-th smallest element in it.
Note:
You may assume k is always valid, i.e. 1 ≤ k ≤ the total number of elements in the BST.
Example 1:
Input: root = [3,1,4,null,2], k = 1
    3
   / \
  1   4
   \
    2
Output: 1
Example 2:
Input: root = [5,3,6,2,4,null,null,1], k = 3
        5
       / \
      3   6
     / \
    2   4
   /
  1
Output: 3
Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?
Solution
The inorder traversal of a binary search tree visits all values in ascending order. Therefore, the k-th value produced by an inorder traversal is the k-th smallest element.
Use a counter: each time a node is visited, increment the counter by 1, and when the counter reaches k, return the current node's value.
Code
-
import java.util.Comparator; import java.util.PriorityQueue; public class Kth_Smallest_Element_in_a_BST { /** * Definition for a binary tree node. * public class TreeNode { * int val; * TreeNode left; * TreeNode right; * TreeNode(int x) { val = x; } * } */ class Solution { PriorityQueue<Integer> heap = new PriorityQueue<>(new Comparator<Integer>() { @Override public int compare(Integer o1, Integer o2) { return o2 - o1; } }); public int kthSmallest(TreeNode root, int k) { dfs(root, k); return heap.peek(); } private void dfs(TreeNode root, int k) { if (root == null) { return; } // maintain heap if (heap.size() < k) { heap.offer(root.val); // followup question, heap.remove() is by object, not index. // so if delete operation, just remove element from both tree and heap } else { int val = root.val; if (val < heap.peek()) { heap.poll(); heap.offer(val); } } dfs(root.left, k); dfs(root.right, k); } } }
-
// OJ: https://leetcode.com/problems/kth-smallest-element-in-a-bst/ // Time: O(N) // Space: O(H) class Solution { public: int kthSmallest(TreeNode* root, int k) { function<int(TreeNode*)> inorder = [&](TreeNode *root) { if (!root) return -1; int val = inorder(root->left); if (val != -1) return val; if (--k == 0) return root->val; return inorder(root->right); }; return inorder(root); } };
-
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        """Return the k-th smallest value (1-based) using an iterative in-order walk.

        An in-order walk of a BST yields values in ascending order, so the walk
        stops at the k-th node it visits. Falls off the end (returning None) if
        that node is never reached.
        """
        pending = []
        node = root
        while node or pending:
            # Push the whole left spine before visiting anything.
            while node:
                pending.append(node)
                node = node.left
            node = pending.pop()
            k -= 1
            if k == 0:
                return node.val
            node = node.right

############

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution(object):
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # Each work item is (action, node): EXPAND schedules a subtree, VISIT
        # consumes a node's value in in-order position.
        EXPAND, VISIT = 1, 0
        work = [(EXPAND, root)]
        while work:
            action, node = work.pop()
            if not node:
                continue
            if action == VISIT:
                k -= 1
                if k == 0:
                    return node.val
            else:
                # Pushed in reverse so they pop as: left subtree, node, right subtree.
                work.extend(((EXPAND, node.right), (VISIT, node), (EXPAND, node.left)))
Solution - Follow up
Counting the nodes of the left subtree by traversing it again on every recursive call is inefficient. Instead, augment the tree node structure with a count field that stores the number of nodes in the subtree rooted at that node, i.e. the node itself plus all nodes in its left and right subtrees.
With this field, the size of any left subtree can be read directly, so the search can decide at each node whether the target lies to the left, at the node, or to the right, and locate the target value quickly.
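Define the new node structure, then build a new tree from the original one recursively. Remember that each generated node's count must include the counts of its left and right child nodes.
Then, in the function that finds the k-th smallest element, first build the new tree and then call the recursive search function. (Note that the code below rebuilds the augmented tree on every call, which costs O(N). To benefit fully in the follow-up scenario, the counts would be kept up to date during insert/delete operations instead.)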
In the recursive function, the count of the left child cannot be read unconditionally, because the left child may not exist. So check it first:
- If the left child exists, let cnt be its count:
  - if k ≤ cnt, search the left subtree with the same k;
  - if k = cnt + 1, the current node's value is the answer;
  - otherwise, search the right subtree with k − cnt − 1 (excluding the left subtree and the current node).
- If it does not exist:
  - when k is 1, return the current node's value directly;
  - otherwise, call the recursive function on the right child with k decremented by 1.
-
class Solution_followUp { public int kthSmallest(TreeNode root, int k) { MyTreeNode node = build(root); return dfs(node, k); } class MyTreeNode { int val; int count; // key point to add up and find k-th element MyTreeNode left; MyTreeNode right; MyTreeNode(int x) { this.val = x; this.count = 1; } }; MyTreeNode build(TreeNode root) { if (root == null) return null; MyTreeNode node = new MyTreeNode(root.val); node.left = build(root.left); node.right = build(root.right); if (node.left != null) node.count += node.left.count; if (node.right != null) node.count += node.right.count; return node; } int dfs(MyTreeNode node, int k) { if (node.left != null) { int cnt = node.left.count; if (k <= cnt) { return dfs(node.left, k); } else if (k > cnt + 1) { return dfs(node.right, k - 1 - cnt); // -1 is to exclude current root } else { // k == cnt + 1 return node.val; } } else { if (k == 1) return node.val; return dfs(node.right, k - 1); } } }
-
// Follow up class Solution_followUp { public: struct MyTreeNode { int val; int count; MyTreeNode *left; MyTreeNode *right; MyTreeNode(int x) : val(x), count(1), left(NULL), right(NULL) {} }; MyTreeNode* build(TreeNode* root) { if (!root) return NULL; MyTreeNode *node = new MyTreeNode(root->val); node->left = build(root->left); node->right = build(root->right); if (node->left) node->count += node->left->count; if (node->right) node->count += node->right->count; return node; } int kthSmallest(TreeNode* root, int k) { MyTreeNode *node = build(root); return helper(node, k); } int helper(MyTreeNode* node, int k) { if (node->left) { int cnt = node->left->count; if (k <= cnt) { return helper(node->left, k); } else if (k > cnt + 1) { return helper(node->right, k - 1 - cnt); } return node->val; } else { if (k == 1) return node->val; return helper(node->right, k - 1); } } };
-
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution_followUp:
    def kthSmallest(self, root: TreeNode, k: int) -> int:
        """Return the k-th smallest value (1-based) via a size-augmented copy of the tree."""
        return self.dfs(self.build(root), k)

    class MyTreeNode:
        """Copy of a tree node that also records the size of its subtree."""

        def __init__(self, x):
            self.val = x
            self.count = 1  # this node alone; build() adds the children's counts
            self.left = None
            self.right = None

    def build(self, root: TreeNode) -> MyTreeNode:
        """Recursively copy `root` into MyTreeNode objects with subtree counts.

        Returns None for an empty subtree.
        """
        if not root:
            return None
        node = self.MyTreeNode(root.val)
        node.left = self.build(root.left)
        node.right = self.build(root.right)
        node.count += sum(child.count for child in (node.left, node.right) if child)
        return node

    def dfs(self, node: MyTreeNode, k: int) -> int:
        """Walk down the augmented tree to the k-th smallest value.

        At each node the left subtree's count tells whether the answer lies on
        the left, at the node, or on the right (with k reduced accordingly).
        """
        while True:
            if not node.left:
                if k == 1:
                    return node.val
                k, node = k - 1, node.right
                continue
            left_size = node.left.count
            if k <= left_size:
                node = node.left
            elif k > left_size + 1:
                # Skip the left subtree and the current node.
                k, node = k - 1 - left_size, node.right
            else:
                return node.val

############

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution(object):
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # Each work item is (action, node): EXPAND schedules a subtree, VISIT
        # consumes a node's value in in-order position.
        EXPAND, VISIT = 1, 0
        work = [(EXPAND, root)]
        while work:
            action, node = work.pop()
            if not node:
                continue
            if action == VISIT:
                k -= 1
                if k == 0:
                    return node.val
            else:
                # Pushed in reverse so they pop as: left subtree, node, right subtree.
                work.extend(((EXPAND, node.right), (VISIT, node), (EXPAND, node.left)))