# Question

Formatted question description: https://leetcode.ca/all/230.html

# 230. Kth Smallest Element in a BST

Medium

## Description

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:

You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1


Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3


Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

## Solution

The inorder traversal of a Binary Search Tree (BST) naturally produces a sequence of values in ascending order. To find the k-th smallest element in a BST, one can perform an inorder traversal and select the k-th element encountered.

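Keep a counter during the traversal and increment it each time a node is visited. When the counter reaches k, the current node's value is the k-th smallest, and the traversal can stop early. This takes O(H + k) time and O(H) space, where H is the height of the tree. The first Java listing in the Code section takes a different route: it visits every node and keeps the k smallest values in a max-heap, which costs O(N log k) time.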
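The counter idea can be sketched in a few lines. This is a minimal illustration, not the page's reference solution: the `inorder` generator and the `kth_smallest` name are hypothetical helpers, and `TreeNode` mirrors the definition used in the listings further down.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield the values of a BST in ascending order (hypothetical helper)."""
    if node is not None:
        yield from inorder(node.left)
        yield node.val
        yield from inorder(node.right)


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed), assuming 1 <= k <= tree size."""
    count = 0
    for val in inorder(root):
        count += 1      # one more node visited in sorted order
        if count == k:
            return val  # stop early; the rest of the tree is never visited
    raise ValueError("k is larger than the number of nodes")


# Example 1 from the description: root = [3,1,4,null,2], k = 1
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest(root, 1))  # prints 1
```

Because the generator is lazy, returning from the loop abandons the traversal as soon as the k-th value has been produced.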

## Solution Follow up - Improved Solution for Efficiency

When queries are frequent, repeating the traversal (or recounting the nodes of a left subtree) on every call is wasteful. To optimize, augment each BST node with the size of the subtree rooted at it: the node itself plus all nodes in its left and right subtrees.

With that count stored, the size of any left subtree can be read in O(1), so each step of the search can decide immediately whether the target lies in the left subtree, at the current node, or in the right subtree.

### Implementing the Enhanced Node Structure

• New Node Structure: Introduce an augmented node structure that maintains a count of the total nodes rooted at that node.
• Tree Transformation: Transform the original tree into the new structure, either by building a new tree or by recursively updating the original tree in place. Each node's count must equal 1 (for the node itself) plus the counts of its left and right children, where a missing child contributes 0.
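The two steps above might look as follows. This is a sketch under assumptions: `CountedNode`, `size`, and `build` are illustrative names (the listings in the Code section call the node type `MyTreeNode`), and the variant shown builds a new tree rather than updating the original in place.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class CountedNode:
    """Hypothetical augmented node: `count` is the size of the subtree rooted here."""

    def __init__(self, val: int):
        self.val = val
        self.count = 1  # a node always counts itself
        self.left: Optional["CountedNode"] = None
        self.right: Optional["CountedNode"] = None


def size(node: Optional[CountedNode]) -> int:
    """Subtree size; a missing child contributes 0."""
    return node.count if node is not None else 0


def build(root: Optional[TreeNode]) -> Optional[CountedNode]:
    """Copy a plain BST into a counted BST in one post-order pass, O(N)."""
    if root is None:
        return None
    node = CountedNode(root.val)
    node.left = build(root.left)
    node.right = build(root.right)
    node.count = 1 + size(node.left) + size(node.right)
    return node


# Example 2 from the description: root = [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
counted = build(root)
print(counted.count, counted.left.count, counted.right.count)  # prints 6 4 1
```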

### Finding the k-th Smallest Element with the Enhanced Tree

• Tree Preparation: Begin by transforming the BST into the enhanced structure.
• Recursive Search: At each node, let cnt be the count stored in the left child. The left child might not exist, so its count cannot be read unconditionally; the two cases are handled separately.
• With Left Child:
  • If k ≤ cnt, the target is in the left subtree: recurse on the left child with the same k.
  • If k = cnt + 1, the current node is the k-th smallest: return its value.
  • Otherwise, recurse on the right child with k − cnt − 1, which skips the left subtree and the current node.
• Without Left Child:
  • If k is 1, return the current node's value immediately.
  • Otherwise, continue the search recursively on the right child, decrementing k by 1.

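Once the counts are in place, each query follows a single root-to-node path and costs O(H) instead of O(H + k). The gain only materializes if the counted tree is kept and its counts are updated on every insert and delete. The `Solution_followUp` listings in the Code section rebuild the counted tree on each call, so a single call there still costs O(N).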
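To illustrate how the counts could be kept current between queries, here is a hedged sketch with an insert that updates counts along its path and an iterative count-guided search. The names `CountedNode`, `size`, `insert`, and `kth_smallest` are assumptions for this example; deletion and rebalancing are omitted, so both operations run in O(H) on a possibly unbalanced tree.

```python
from typing import Optional


class CountedNode:
    """Hypothetical augmented node: `count` is the size of the subtree rooted here."""

    def __init__(self, val: int):
        self.val = val
        self.count = 1
        self.left: Optional["CountedNode"] = None
        self.right: Optional["CountedNode"] = None


def size(node: Optional[CountedNode]) -> int:
    return node.count if node is not None else 0


def insert(node: Optional[CountedNode], val: int) -> CountedNode:
    """Insert val and return the subtree root, bumping counts on the way back up."""
    if node is None:
        return CountedNode(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.count += 1  # exactly one new node was added somewhere below
    return node


def kth_smallest(node: Optional[CountedNode], k: int) -> int:
    """Follow one root-to-node path, using left-subtree sizes to pick a direction."""
    while node is not None:
        left = size(node.left)
        if k <= left:
            node = node.left      # target is among the `left` smaller values
        elif k == left + 1:
            return node.val       # current node is the k-th smallest
        else:
            k -= left + 1         # skip the left subtree and this node
            node = node.right
    raise ValueError("k is out of range")


root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)
print(kth_smallest(root, 3))  # prints 3

root = insert(root, 0)        # the tree changes; no rebuild is needed
print(kth_smallest(root, 3))  # prints 2
```

Using `size()` removes the need for separate "with left child" and "without left child" branches, since a missing child simply counts as 0.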

# Code

• Java

import java.util.Comparator;
import java.util.PriorityQueue;

public class Kth_Smallest_Element_in_a_BST {
    /**
     * Definition for a binary tree node.
     * public class TreeNode {
     *     int val;
     *     TreeNode left;
     *     TreeNode right;
     *     TreeNode(int x) { val = x; }
     * }
     */
    class Solution {

        // Max-heap holding the k smallest values seen so far; its top is the k-th smallest.
        PriorityQueue<Integer> heap = new PriorityQueue<>(new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return Integer.compare(o2, o1); // reverse order without the overflow risk of o2 - o1
            }
        });

        public int kthSmallest(TreeNode root, int k) {
            heap.clear(); // reset state so the method can be called more than once

            dfs(root, k);

            return heap.peek();
        }

        private void dfs(TreeNode root, int k) {

            if (root == null) {
                return;
            }

            // maintain heap
            if (heap.size() < k) {
                heap.offer(root.val);

                // Follow-up note: heap.remove() removes by object, not by index,
                // so for a delete operation, remove the element from both the tree and the heap.

            } else {
                int val = root.val;
                if (val < heap.peek()) {
                    // val belongs among the k smallest; evict the current largest of them
                    heap.poll();
                    heap.offer(val);
                }
            }
            dfs(root.left, k);
            dfs(root.right, k);
        }
    }

}

class Solution_followUp {
    public int kthSmallest(TreeNode root, int k) {
        MyTreeNode node = build(root);
        return dfs(node, k);
    }

    class MyTreeNode {
        int val;
        int count; // key point to add up and find k-th element
        MyTreeNode left;
        MyTreeNode right;
        MyTreeNode(int x) {
            this.val = x;
            this.count = 1;
        }
    }

    MyTreeNode build(TreeNode root) {
        if (root == null) return null;
        MyTreeNode node = new MyTreeNode(root.val);
        node.left = build(root.left);
        node.right = build(root.right);
        if (node.left != null) node.count += node.left.count;
        if (node.right != null) node.count += node.right.count;
        return node;
    }

    int dfs(MyTreeNode node, int k) {
        if (node.left != null) {
            int cnt = node.left.count;
            if (k <= cnt) {
                return dfs(node.left, k);
            } else if (k > cnt + 1) {
                return dfs(node.right, k - 1 - cnt); // -1 is to exclude current root
            } else { // k == cnt + 1
                return node.val;
            }
        } else {
            if (k == 1) return node.val;
            return dfs(node.right, k - 1);
        }
    }
}

############

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // Iterative inorder traversal with an explicit stack
        Deque<TreeNode> stk = new ArrayDeque<>();
        while (root != null || !stk.isEmpty()) {
            if (root != null) {
                stk.push(root);
                root = root.left;
            } else {
                root = stk.pop();
                if (--k == 0) {
                    return root.val;
                }
                root = root.right;
            }
        }
        return 0; // not reached when k is valid
    }
}

• C++

// OJ: https://leetcode.com/problems/kth-smallest-element-in-a-bst/
// Time: O(N)
// Space: O(H)
#include <functional>
using namespace std;

class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        // -1 means "not found yet", so this assumes no node holds the value -1
        function<int(TreeNode*)> inorder = [&](TreeNode *root) {
            if (!root) return -1;
            int val = inorder(root->left);
            if (val != -1) return val;
            if (--k == 0) return root->val;
            return inorder(root->right);
        };
        return inorder(root);
    }
};

• Python

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        # Iterative inorder traversal with an explicit stack
        stk = []
        while root or stk:
            if root:
                stk.append(root)
                root = root.left
            else:
                root = stk.pop()
                k -= 1
                if k == 0:
                    return root.val
                root = root.right

############

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution_followUp:
    def kthSmallest(self, root: TreeNode, k: int) -> int:
        node = self.build(root)
        return self.dfs(node, k)

    class MyTreeNode:
        def __init__(self, x):
            self.val = x
            self.count = 1 # default 1
            self.left = None
            self.right = None

    def build(self, root: TreeNode) -> MyTreeNode:
        if not root:
            return None
        node = self.MyTreeNode(root.val) # default count already is 1
        node.left = self.build(root.left)
        node.right = self.build(root.right)
        if node.left:
            node.count += node.left.count
        if node.right:
            node.count += node.right.count
        return node

    def dfs(self, node: MyTreeNode, k: int) -> int:
        if node.left:
            cnt = node.left.count
            if k < cnt + 1:
                return self.dfs(node.left, k)
            elif k > cnt + 1:
                return self.dfs(node.right, k - 1 - cnt)
            else: # k == cnt+1
                return node.val
        else:
            if k == 1: # cannot move to beginning of dfs()
                return node.val
            return self.dfs(node.right, k - 1)


• Go

/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
func kthSmallest(root *TreeNode, k int) int {
    stk := []*TreeNode{}
    for root != nil || len(stk) > 0 {
        if root != nil {
            stk = append(stk, root)
            root = root.Left
        } else {
            root = stk[len(stk)-1]
            stk = stk[:len(stk)-1]
            k--
            if k == 0 {
                return root.Val
            }
            root = root.Right
        }
    }
    return 0
}

• TypeScript

/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

function kthSmallest(root: TreeNode | null, k: number): number {
    // -1 means "not found yet", so this assumes no node holds the value -1
    const dfs = (root: TreeNode | null) => {
        if (root == null) {
            return -1;
        }
        const { val, left, right } = root;
        const l = dfs(left);
        if (l !== -1) {
            return l;
        }
        k--;
        if (k === 0) {
            return val;
        }
        return dfs(right);
    };
    return dfs(root);
}


• Rust

// Definition for a binary tree node.
// #[derive(Debug, PartialEq, Eq)]
// pub struct TreeNode {
//   pub val: i32,
//   pub left: Option<Rc<RefCell<TreeNode>>>,
//   pub right: Option<Rc<RefCell<TreeNode>>>,
// }
//
// impl TreeNode {
//   #[inline]
//   pub fn new(val: i32) -> Self {
//     TreeNode {
//       val,
//       left: None,
//       right: None
//     }
//   }
// }
use std::rc::Rc;
use std::cell::RefCell;
impl Solution {
    // Inorder traversal that collects values into `res`.
    // Note: `take()` detaches each visited child, so the tree is consumed.
    fn dfs(root: Option<Rc<RefCell<TreeNode>>>, res: &mut Vec<i32>, k: usize) {
        if let Some(node) = root {
            let mut node = node.borrow_mut();
            Self::dfs(node.left.take(), res, k);
            res.push(node.val);
            if res.len() >= k {
                return;
            }
            Self::dfs(node.right.take(), res, k);
        }
    }
    pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
        let k = k as usize;
        let mut res: Vec<i32> = Vec::with_capacity(k);
        Self::dfs(root, &mut res, k);
        res[k - 1]
    }
}