Welcome to Subscribe On Youtube
Formatted question description: https://leetcode.ca/all/1740.html
1740. Find Distance in a Binary Tree
Level
Medium
Description
Given the root of a binary tree and two integers p and q, return the distance between the nodes of value p and value q in the tree.
The distance between two nodes is the number of edges on the path from one to the other.
Example 1:
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0
Output: 3
Explanation: There are 3 edges between 5 and 0: 5-3-1-0.
Example 2:
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 7
Output: 2
Explanation: There are 2 edges between 5 and 7: 5-2-7.
Example 3:
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 5
Output: 0
Explanation: The distance between a node and itself is 0.
Constraints:
- The number of nodes in the tree is in the range [1, 10^4].
- 0 <= Node.val <= 10^9
- All Node.val are unique.
- p and q are values in the tree.
Solution
Related: 236-Lowest-Common-Ancestor-of-a-Binary-Tree/
First, find the lowest common ancestor (LCA) of the two nodes with values p and q. Next, compute the distance from the LCA to each of those two nodes. Finally, return the sum of the two distances.
Alternative-1
Use a hashmap to store the distance of every node pair while finding the LCA; each answer is then a single map lookup.
Alternative-2
Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2 * Dist(root, lca)
- ‘n1’ and ‘n2’ are the two given keys.
- ‘root’ is the root of the given binary tree.
- ‘lca’ is the lowest common ancestor of n1 and n2.
- Dist(n1, n2) is the distance between n1 and n2.
-
/** * Definition for a binary tree node. * public class TreeNode { * int val; * TreeNode left; * TreeNode right; * TreeNode() {} * TreeNode(int val) { this.val = val; } * TreeNode(int val, TreeNode left, TreeNode right) { * this.val = val; * this.left = left; * this.right = right; * } * } */ // alternative: hashmap to store distance of every node-pair during finding LCA, then just map look up class Solution { public int findDistance(TreeNode root, int p, int q) { if (p == q) return 0; TreeNode ancestor = lowestCommonAncestor(root, p, q); return getDistance(ancestor, p) + getDistance(ancestor, q); } public TreeNode lowestCommonAncestor(TreeNode root, int p, int q) { if (root == null) return null; if (root.val == p || root.val == q) return root; TreeNode left = lowestCommonAncestor(root.left, p, q); TreeNode right = lowestCommonAncestor(root.right, p, q); if (left != null && right != null) return root; return left == null ? right : left; } public int getDistance(TreeNode node, int val) { if (node == null) return -1; if (node.val == val) return 0; int leftDistance = getDistance(node.left, val); int rightDistance = getDistance(node.right, val); int subDistance = Math.max(leftDistance, rightDistance); return subDistance >= 0 ? subDistance + 1 : -1; } } ############ /** * Definition for a binary tree node. 
* public class TreeNode { * int val; * TreeNode left; * TreeNode right; * TreeNode() {} * TreeNode(int val) { this.val = val; } * TreeNode(int val, TreeNode left, TreeNode right) { * this.val = val; * this.left = left; * this.right = right; * } * } */ class Solution { public int findDistance(TreeNode root, int p, int q) { TreeNode g = lca(root, p, q); return dfs(g, p) + dfs(g, q); } private int dfs(TreeNode root, int v) { if (root == null) { return -1; } if (root.val == v) { return 0; } int left = dfs(root.left, v); int right = dfs(root.right, v); if (left == -1 && right == -1) { return -1; } return 1 + Math.max(left, right); } private TreeNode lca(TreeNode root, int p, int q) { if (root == null || root.val == p || root.val == q) { return root; } TreeNode left = lca(root.left, p, q); TreeNode right = lca(root.right, p, q); if (left == null) { return right; } if (right == null) { return left; } return root; } }
-
// OJ: https://leetcode.com/problems/find-distance-in-a-binary-tree/ // Time: O(N) // Space: O(H) class Solution { bool findPath(TreeNode *root, int target, vector<TreeNode*> &path) { if (!root) return false; path.push_back(root); if (root->val == target || findPath(root->left, target, path) || findPath(root->right, target, path)) return true; path.pop_back(); return false; } public: int findDistance(TreeNode* root, int p, int q) { if (p == q) return 0; vector<TreeNode*> a, b; findPath(root, p, a); findPath(root, q, b); int i = 0; while (i < a.size() && i < b.size() && a[i] == b[i]) ++i; return a.size() + b.size() - 2 * i; } };
-
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def findDistance(self, root: "Optional[TreeNode]", p: int, q: int) -> int:
        """Return the number of edges between the nodes valued ``p`` and ``q``.

        Finds the lowest common ancestor (LCA) of the two values, then adds the
        depth of each value below that ancestor.
        """

        def lca(root, p, q):
            # Lowest node that is p or q itself, or has them in different
            # subtrees; None when neither value occurs under ``root``.
            if root is None or root.val in [p, q]:
                return root
            left = lca(root.left, p, q)
            right = lca(root.right, p, q)
            if left is None:
                return right
            if right is None:
                return left
            return root

        def dfs(root, v):
            # Edges from ``root`` down to the node valued ``v``; -1 if absent.
            if root is None:
                return -1
            if root.val == v:
                return 0
            left, right = dfs(root.left, v), dfs(root.right, v)
            if left == right == -1:
                return -1
            return 1 + max(left, right)

        g = lca(root, p, q)
        return dfs(g, p) + dfs(g, q)


#############


class Solution:
    def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
        """Return the number of edges between the nodes valued ``p`` and ``q``.

        Recursion-free alternative: an iterative DFS records every node's
        parent value and depth. Then the deeper of the two endpoints is walked
        upward one edge at a time until both meet at their lowest common
        ancestor, which realizes
        Dist(p, q) = Dist(root, p) + Dist(root, q) - 2 * Dist(root, lca).

        Nodes are keyed by value, which relies on values being unique (the
        problem guarantees this). Returns -1 if the tree is empty or if either
        value is missing.
        """
        if root is None:
            return -1

        parent = {root.val: None}
        depth = {root.val: 0}
        stack = [root]
        while stack:
            node = stack.pop()
            for child in (node.left, node.right):
                if child is not None:
                    parent[child.val] = node.val
                    depth[child.val] = depth[node.val] + 1
                    stack.append(child)

        if p not in depth or q not in depth:
            return -1

        # Always climb from the deeper endpoint; each climb is one edge. The
        # root is never climbed from unless both endpoints are the root.
        a, b, steps = p, q, 0
        while a != b:
            if depth[a] < depth[b]:
                a, b = b, a
            a = parent[a]
            steps += 1
        return steps
-
/** * Definition for a binary tree node. * type TreeNode struct { * Val int * Left *TreeNode * Right *TreeNode * } */ func findDistance(root *TreeNode, p int, q int) int { var lca func(root *TreeNode, p int, q int) *TreeNode lca = func(root *TreeNode, p int, q int) *TreeNode { if root == nil || root.Val == p || root.Val == q { return root } left, right := lca(root.Left, p, q), lca(root.Right, p, q) if left == nil { return right } if right == nil { return left } return root } var dfs func(root *TreeNode, v int) int dfs = func(root *TreeNode, v int) int { if root == nil { return -1 } if root.Val == v { return 0 } left, right := dfs(root.Left, v), dfs(root.Right, v) if left == -1 && right == -1 { return -1 } return 1 + max(left, right) } g := lca(root, p, q) return dfs(g, p) + dfs(g, q) } func max(a, b int) int { if a > b { return a } return b }