Binary Search Pitfalls illustrated by examples
Binary search is very difficult! It's so hard to write correctly, because off-by-one errors can cause:
1. elements in the search space (array) to be skipped
2. index out of bounds errors
3. the loop to get stuck forever
These errors are really annoying, and they get me every single time. If one of them tricked me in an interview, I imagine it would be frustrating to debug and reason about under stress.
If you've had the same experience, I highly encourage you to read through the reference at the end of this article. Here, I'll walk through a few binary search problems, how they can go wrong, and how to implement them correctly.
Example 1: Find Smallest Letter Greater Than Target
For the problem description, please refer to LeetCode (problem 744).
My first attempt at the problem:
class Solution:
    def nextGreatestLetter(self, letters, target):
        # Search the closed range [left, right]
        left = 0
        right = len(letters) - 1
        while left <= right:
            mid = left + (right - left) // 2
            if letters[mid] <= target:
                left = mid + 1
            else:  # letters[mid] > target
                right = mid
        if letters[right] > target:
            return letters[right]
        return letters[0]
assigning initial range
I always like to set left = 0 and right = len(array) - 1 (the index of the last element), so that both point to an existing element in the array.
termination condition
while left <= right:
Here I'm saying the loop should stop once the range [left, right] contains no elements (that is, once left > right). This is fine so far; after all, the termination condition is just a way to put a stop to the binary search process.
loop body
Here’s where we decide how to do the binary search.
if letters[mid] <= target:
left = mid + 1
In this if branch, we're saying that if the middle letter we're looking at is smaller than or equal to the target, then neither letters[mid] nor anything to its left can be the answer. The answer must be in the range [mid + 1, right], so we set left = mid + 1.
else: # letters[mid] > target:
right = mid
In this else branch, we're saying that if the middle letter we're looking at is greater than the target, then this middle letter could be the answer, and there may be more potential answers to its left. The answer must be in the range [left, mid]. Setting right = mid - 1 here would be wrong, because it would throw away a possible answer. Be careful!
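To see why, here is a minimal sketch (the function name next_greatest_right_minus_one is hypothetical) that keeps the while left <= right loop and the post-loop check on letters[right], but uses right = mid - 1. With letters = ["c", "f", "j"] and target = "c", the correct answer is "f", yet this version discards it:

```python
def next_greatest_right_minus_one(letters, target):
    """Deliberately wrong variant: uses right = mid - 1 in the else branch."""
    left, right = 0, len(letters) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if letters[mid] <= target:
            left = mid + 1
        else:  # letters[mid] > target, so mid might be the answer...
            right = mid - 1  # ...but this drops mid from the range
    if letters[right] > target:
        return letters[right]
    return letters[0]


print(next_greatest_right_minus_one(["c", "f", "j"], "c"))  # c, but the expected answer is f
```

Tracing it by hand: the first iteration sees "f" > "c" and sets right = 0, throwing "f" away; the loop then ends with right pointing at "c".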
So far so good. Now consider the case where left == right (and therefore left == mid == right) and letters[mid] > target.
In this case, we set right = mid and go around the loop again. On the next iteration, left still equals right, nothing has changed, and so we are stuck in the loop forever!
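One way to confirm the infinite loop without hanging your terminal is to cap the number of iterations. This is a minimal sketch: first_attempt_terminates and max_iters are hypothetical names, and the body is the first attempt's loop with a counter added.

```python
def first_attempt_terminates(letters, target, max_iters=100):
    """Run the first attempt's loop; return False if it exceeds max_iters."""
    left, right = 0, len(letters) - 1
    iters = 0
    while left <= right:
        iters += 1
        if iters > max_iters:
            return False  # stuck: left == right and right = mid changes nothing
        mid = left + (right - left) // 2
        if letters[mid] <= target:
            left = mid + 1
        else:  # letters[mid] > target
            right = mid
    return True


print(first_attempt_terminates(["c", "f", "j"], "a"))  # False: stuck at left == right == 0
print(first_attempt_terminates(["c", "f", "j"], "z"))  # True: every letter is <= target
```

Note that the loop does terminate when every letter is <= target, because then only the left = mid + 1 branch runs; the bug only shows up when the else branch is taken with a single element left.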
An easy fix for this problem is to add a check that breaks out of the loop when left == right. Or, we could simply change the termination condition to
while left < right
so that the loop terminates on its own when left == right.
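class Solution:
    def nextGreatestLetter(self, letters, target):
        left = 0
        right = len(letters) - 1
        # Stop as soon as one candidate remains (left == right)
        while left < right:
            mid = left + (right - left) // 2
            if letters[mid] <= target:
                # letters[mid] and everything to its left are too small
                left = mid + 1
            else:  # letters[mid] > target
                # letters[mid] may be the answer, so keep it in the range
                right = mid
        # Check the single remaining candidate
        if letters[right] > target:
            return letters[right]
        # No letter is greater than target: wrap around
        return letters[0]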
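The other fix mentioned above, an explicit break when left == right, can be sketched as a standalone function (the name next_greatest_letter_with_break is hypothetical). For a non-empty list, left = mid + 1 never jumps past right while left < right, so the loop can only exit through the break, and this behaves the same as the while left < right version:

```python
def next_greatest_letter_with_break(letters, target):
    """Same search as the first attempt, plus an explicit break."""
    left, right = 0, len(letters) - 1
    while left <= right:
        if left == right:
            break  # one candidate left; check it after the loop
        mid = left + (right - left) // 2
        if letters[mid] <= target:
            left = mid + 1
        else:  # letters[mid] > target
            right = mid
    if letters[right] > target:
        return letters[right]
    return letters[0]


print(next_greatest_letter_with_break(["c", "f", "j"], "c"))  # f
```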
after loop terminates
Once the loop terminates, we know left == right: we have narrowed the search space down to a single element, letters[right], which may never have been compared with the target (for example, when every letter is <= target, right never moves). So we make one simple check to verify that it is indeed greater than the target. If it isn't, no letter is greater than the target, and according to the problem statement we should return letters[0].
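Since the post-loop check is exactly where off-by-one bugs hide, it's worth cross-checking the fixed solution against a brute-force linear scan on random inputs. Here is a minimal sketch, with the solution rewritten as a standalone function; brute_force and cross_check are hypothetical helpers, and the random inputs (sorted lowercase letters, duplicates allowed) are an assumption of this sketch.

```python
import random
import string


def next_greatest_letter(letters, target):
    """The fixed binary search from above, as a standalone function."""
    left, right = 0, len(letters) - 1
    while left < right:
        mid = left + (right - left) // 2
        if letters[mid] <= target:
            left = mid + 1
        else:  # letters[mid] > target
            right = mid
    if letters[right] > target:
        return letters[right]
    return letters[0]


def brute_force(letters, target):
    """Linear scan: first letter greater than target, else wrap around."""
    for ch in letters:
        if ch > target:
            return ch
    return letters[0]


def cross_check(trials=10_000, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        letters = sorted(rng.choices(string.ascii_lowercase, k=rng.randint(1, 8)))
        target = rng.choice(string.ascii_lowercase)
        expected = brute_force(letters, target)
        assert next_greatest_letter(letters, target) == expected, (letters, target)
    return True


print(cross_check())  # True
```

Small arrays (1 to 8 letters) are used on purpose: that is where the size-1 and size-2 corner cases come up most often.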
summary: choose the termination condition so that the loop actually terminates. After the loop terminates, look at the element(s) you've narrowed down to and do a simple check.
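For comparison, Python's standard library already implements this search: bisect.bisect_right(a, x) returns the index of the first element greater than x in a sorted list, which is exactly the boundary Example 1 is looking for. A short sketch (the function name is hypothetical); the modulo handles the wrap-around case where no letter is greater than the target:

```python
import bisect


def next_greatest_letter_bisect(letters, target):
    # Index of the first letter strictly greater than target (len(letters) if none)
    i = bisect.bisect_right(letters, target)
    # i == len(letters) means no such letter, so wrap around to letters[0]
    return letters[i % len(letters)]


print(next_greatest_letter_bisect(["c", "f", "j"], "c"))  # f
print(next_greatest_letter_bisect(["c", "f", "j"], "j"))  # c
```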
Example 2: Time Based Key-Value Store
Again, please refer to LeetCode (problem 981) for the problem description.
Here's my first, failed attempt at the problem. Again, the basic idea is right, but it can get stuck in the loop forever.
from collections import defaultdict


class TimeMap:
    def __init__(self):
        """
        Initialize your data structure here.
        """
        # key -> list of (timestamp, value) pairs, kept in set() order (timestamps increase)
        self.store = defaultdict(list)

    def set(self, key: str, value: str, timestamp: int) -> None:
        self.store[key].append((timestamp, value))

    def get(self, key: str, timestamp: int) -> str:
        if key not in self.store:
            return ""
        arr = self.store[key]
        left = 0
        right = len(arr) - 1
        while left < right:
            mid = left + (right - left) // 2
            mid_timestamp = arr[mid][0]
            if mid_timestamp <= timestamp:
                left = mid
            else:
                right = mid - 1
        if arr[left][0] <= timestamp:
            return arr[left][1]
        return ""
For this code, consider a key whose stored timestamps are
arr = [1, 4]
When we call get with timestamp = 4, the answer should be the value stored at timestamp 4,
but our code gets stuck: with left = 0 and right = 1, mid is always 0, so left is set to 0 again and again and again…
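A small trace makes the repetition visible. This sketch (the name trace_first_attempt is hypothetical) runs the first attempt's loop on the timestamps alone and records (left, right, mid) for at most steps iterations:

```python
def trace_first_attempt(timestamps, timestamp, steps=3):
    """Record (left, right, mid) for up to `steps` iterations of the buggy loop."""
    left, right = 0, len(timestamps) - 1
    states = []
    while left < right and len(states) < steps:
        mid = left + (right - left) // 2
        states.append((left, right, mid))
        if timestamps[mid] <= timestamp:
            left = mid  # with right == left + 1, mid == left, so nothing changes
        else:
            right = mid - 1
    return states


print(trace_first_attempt([1, 4], 4))  # [(0, 1, 0), (0, 1, 0), (0, 1, 0)]
```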
If you remember, this code looks very similar to our solution to Example 1:
while left < right:
    mid = left + (right - left) // 2
    if letters[mid] <= target:
        left = mid + 1
    else:  # letters[mid] > target
        right = mid
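Why didn't this code get stuck? Because when the search space has size 2, mid == left, and either branch shrinks it to size 1 (left = mid + 1 moves left up to right, and right = mid moves right down to left), so the loop exits. This is not true in the TimeMap solution, where left = mid leaves the range unchanged. A fix looks like this: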
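def get(self, key: str, timestamp: int) -> str:
    if key not in self.store:
        return ""
    arr = self.store[key]
    left = 0
    right = len(arr) - 1
    ans = -1  # index of the latest candidate seen so far with timestamp <= query
    while left < right:
        mid = left + (right - left) // 2
        mid_timestamp = arr[mid][0]
        if mid_timestamp <= timestamp:
            ans = mid  # arr[mid] is a valid candidate; look for a later one
            left = mid + 1
        else:
            right = mid - 1
    # arr[left] may not have been compared yet
    if arr[left][0] <= timestamp:
        return arr[left][1]
    if ans != -1:
        return arr[ans][1]
    return ""
Here, we skip past arr[mid] (left = mid + 1), but first record mid in ans, since arr[mid] is a valid candidate, and we take it into account at the end. Recording left instead of mid would be a bug: with timestamps [1, 2, 3, 4, 5] and a query of 3, it would return the value stored at timestamp 1.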
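To gain confidence in the recorded-candidate approach, here is a minimal sketch that cross-checks it against a linear scan. It is written as a standalone function over a list of (timestamp, value) pairs (get_with_candidate, linear_get and cross_check are hypothetical names), and it records mid, the index just confirmed to satisfy arr[mid][0] <= timestamp, as the candidate.

```python
import random


def get_with_candidate(arr, timestamp):
    """Binary search that remembers the last mid known to satisfy the condition."""
    left, right = 0, len(arr) - 1
    ans = -1
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid][0] <= timestamp:
            ans = mid  # arr[mid] is a valid candidate; look for a later one
            left = mid + 1
        else:
            right = mid - 1
    if arr[left][0] <= timestamp:
        return arr[left][1]
    if ans != -1:
        return arr[ans][1]
    return ""


def linear_get(arr, timestamp):
    """Value with the largest timestamp <= timestamp, or "" if none."""
    result = ""
    for ts, value in arr:
        if ts <= timestamp:
            result = value
    return result


def cross_check(trials=5_000, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        # Sorted, distinct timestamps, matching how set() appends them
        stamps = sorted(rng.sample(range(1, 30), rng.randint(1, 8)))
        arr = [(ts, f"v{ts}") for ts in stamps]
        query = rng.randint(0, 31)
        assert get_with_candidate(arr, query) == linear_get(arr, query), (arr, query)
    return True


print(cross_check())  # True
```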
Note that the first attempt only gets stuck in the corner case where the search space shrinks to size 2. In every other case, each iteration shrinks the search space, and the code works fine.
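We can check the size-2 claim mechanically. This sketch (the helper shrinks_when_left_moves is hypothetical) takes a range of a given size, applies the first attempt's left = mid update, and reports whether the range got smaller; the other branch, right = mid - 1, always shrinks it. It also tries rounding mid up instead of down, a common alternative fix that the article does not use:

```python
def shrinks_when_left_moves(size, round_up=False):
    """Does `left = mid` shrink a range [left, right] of the given size?"""
    left, right = 0, size - 1
    if round_up:
        mid = left + (right - left + 1) // 2  # upper middle
    else:
        mid = left + (right - left) // 2  # lower middle, as in the code above
    new_left = mid
    return right - new_left + 1 < size


print([s for s in range(2, 10) if not shrinks_when_left_moves(s)])  # [2]
print([s for s in range(2, 10) if not shrinks_when_left_moves(s, round_up=True)])  # []
```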
In Example 1, we actually had a similar corner case, where the loop got stuck when the search space shrank to size 1. Our solution was to change the termination condition to
while left < right
and handle the last remaining element outside of the loop. We can do something similar here!
We could change the termination condition to
while left+1 < right
and handle the last two remaining elements outside of the loop. Complete code below:
def get(self, key: str, timestamp: int) -> str:
    if key not in self.store:
        return ""
    arr = self.store[key]
    left = 0
    right = len(arr) - 1
    # Loop only while at least 3 elements remain, so left < mid < right
    while left + 1 < right:
        mid = left + (right - left) // 2
        mid_timestamp = arr[mid][0]
        if mid_timestamp <= timestamp:
            left = mid
        else:
            right = mid - 1
    # At most 2 candidates remain; prefer the later timestamp
    if arr[right][0] <= timestamp:
        return arr[right][1]
    if arr[left][0] <= timestamp:
        return arr[left][1]
    return ""
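Another common fix, which the article doesn't use, is to keep while left < right and left = mid but round mid up, so that mid can never equal left and the size-2 case always shrinks. A minimal sketch of the whole class under that approach (TimeMapUpperMid is a hypothetical name; this is an alternative, not the code above), with a quick usage check:

```python
from collections import defaultdict


class TimeMapUpperMid:
    def __init__(self):
        # key -> list of (timestamp, value) pairs, in the order set() was called
        self.store = defaultdict(list)

    def set(self, key: str, value: str, timestamp: int) -> None:
        self.store[key].append((timestamp, value))

    def get(self, key: str, timestamp: int) -> str:
        if key not in self.store:
            return ""
        arr = self.store[key]
        left, right = 0, len(arr) - 1
        while left < right:
            # Upper middle: mid > left whenever left < right, so left = mid makes progress
            mid = left + (right - left + 1) // 2
            if arr[mid][0] <= timestamp:
                left = mid
            else:
                right = mid - 1
        # left == right: the only remaining candidate
        if arr[left][0] <= timestamp:
            return arr[left][1]
        return ""


tm = TimeMapUpperMid()
tm.set("foo", "bar", 1)
tm.set("foo", "bar2", 4)
print(tm.get("foo", 4))  # bar2
print(tm.get("foo", 3))  # bar
print(repr(tm.get("foo", 0)))  # ''
```

The trade-off is that the rounding direction now has to match the update rule: rounding up pairs with left = mid, rounding down pairs with right = mid.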
summary: pay attention to what happens when your search space shrinks to 1 or 2 elements (depending on your termination condition). Will your loop really terminate, or will it get stuck forever?
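Finally, when the goal is working code rather than practice, the standard bisect module sidesteps these termination questions. A minimal sketch, assuming we store timestamps and values in two parallel lists (a layout choice of this sketch, not of the article; TimeMapBisect is a hypothetical name) so that bisect.bisect_right can search the timestamps directly:

```python
import bisect
from collections import defaultdict


class TimeMapBisect:
    def __init__(self):
        # key -> sorted list of timestamps, and the values stored at them
        self.times = defaultdict(list)
        self.values = defaultdict(list)

    def set(self, key: str, value: str, timestamp: int) -> None:
        # Assumes timestamps for a key arrive in increasing order, so the list stays sorted
        self.times[key].append(timestamp)
        self.values[key].append(value)

    def get(self, key: str, timestamp: int) -> str:
        if key not in self.times:
            return ""
        # Index of the first timestamp strictly greater than the query
        i = bisect.bisect_right(self.times[key], timestamp)
        # The element just before it is the latest timestamp <= query, if any
        return self.values[key][i - 1] if i > 0 else ""


tm = TimeMapBisect()
tm.set("foo", "bar", 1)
tm.set("foo", "bar2", 4)
print(tm.get("foo", 4))  # bar2
print(tm.get("foo", 3))  # bar
```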
References:
https://zhu45.org/posts/2018/Jan/12/how-to-write-binary-search-correctly/