Problem statement

https://binarysearch.com/problems/K-Subsequence/

Solution

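The idea is to let dp[i] be the length of the longest valid subsequence that ends at nums[i]. We also keep a segment tree indexed by value, where STree[value] is the maximum length of a subsequence which ends with value. For each nums[i] we query the maximum over the values in [nums[i] - k, nums[i] + k], add one to get dp[i], and then update the tree at nums[i] with dp[i]. The answer is the maximum of dp.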

Complexity

Time complexity is O(n log M), where n is the length of nums and M is the size of the value range (up to 2*10^9).

Code 1

class SegmentTree:
    """Sparse max segment tree stored in a dict keyed by heap-style node index.

    Node ``v`` has children ``2*v`` and ``2*v + 1``. Nodes that were never
    written read as 0, so the tree can span a very large coordinate range
    while storing only the nodes that updates have touched.
    """

    def __init__(self):
        # Imported locally so the class does not rely on the surrounding
        # environment having already imported ``defaultdict``.
        from collections import defaultdict

        # node index -> maximum value stored inside that node's segment
        self.T = defaultdict(int)

    def update(self, v, tl, tr, pos, h):
        """Raise the value stored at position ``pos`` to at least ``h``.

        Args:
            v: Index of the current node; outside callers pass the root, 1.
            tl: Inclusive left bound of the segment covered by node ``v``.
            tr: Inclusive right bound of the segment covered by node ``v``.
            pos: Position to update (assumed to lie within ``[tl, tr]``).
            h: Candidate value; the leaf keeps the larger of old and new.
        """
        if tl == tr:
            self.T[v] = max(self.T.get(v, 0), h)
            return
        tm = (tl + tr) // 2
        if pos <= tm:
            self.update(v * 2, tl, tm, pos, h)
        else:
            self.update(v * 2 + 1, tm + 1, tr, pos, h)
        # ``.get`` avoids materialising the sibling that was never written.
        self.T[v] = max(self.T.get(v * 2, 0), self.T.get(v * 2 + 1, 0))

    def query(self, v, tl, tr, l, r):
        """Return the maximum stored value over positions ``l..r`` inclusive.

        Args:
            v: Index of the current node; outside callers pass the root, 1.
            tl: Inclusive left bound of the segment covered by node ``v``.
            tr: Inclusive right bound of the segment covered by node ``v``.
            l: Inclusive left end of the query (assumed within ``[tl, tr]``).
            r: Inclusive right end of the query (assumed within ``[tl, tr]``).

        Returns:
            The maximum over the range, or 0 when the range is empty
            (``l > r``) or nothing has been stored inside it.
        """
        if l > r:
            return 0
        # Every update writes all nodes on its root-to-leaf path, so a node
        # missing from T has an untouched subtree whose maximum is 0.
        if v not in self.T:
            return 0
        if l == tl and r == tr:
            return self.T[v]
        tm = (tl + tr) // 2
        left_best = self.query(v * 2, tl, tm, l, min(r, tm))
        right_best = self.query(v * 2 + 1, tm + 1, tr, max(l, tm + 1), r)
        return max(left_best, right_best)

class Solution:
    def solve(self, nums, k):
        """Return the best subsequence length found with a value-indexed tree.

        Each value is shifted by ``min(nums)`` so it becomes a non-negative
        position. For every element, the tree is asked for the largest length
        stored at positions within ``k`` of it; that length plus one is the
        element's own length, which is then written back at its position.

        Returns 0 for empty input, otherwise the largest length computed.
        """
        if not nums:
            return 0

        base = min(nums)
        # Right edge of the tree's coordinate range: (max - min) + 1.
        upper = max(nums) - base + 1
        tree = SegmentTree()

        longest = 0
        for raw in nums:
            pos = raw - base
            window_lo = max(0, pos - k)
            window_hi = min(pos + k, upper)
            length = tree.query(1, 0, upper, window_lo, window_hi) + 1
            tree.update(1, 0, upper, pos, length)
            if length > longest:
                longest = length

        return longest

Code 2

Here is an iterative version of the tree with the same complexity; it runs faster in practice and gets the AC verdict. Notice that we need to shift our numbers: if the values in our array span a range of size N (that is, N = max - min), we shift them into [1, ..., N + 1] and create SegmentTree(2*N).

class SegmentTree:
    """Iterative (bottom-up) sparse max segment tree backed by a dict.

    The leaf for position ``i`` lives at node ``i + Q``; node ``j`` has
    children ``2*j`` and ``2*j + 1``. Nodes never written read as 0.
    """

    def __init__(self, Q):
        # Imported locally so the class does not rely on the surrounding
        # environment having already imported ``defaultdict``.
        from collections import defaultdict

        # node index -> maximum value stored beneath that node
        self.T = defaultdict(int)
        # offset added to a position to obtain its leaf's node index
        self.Q = Q

    def query(self, low, high):
        """Return the maximum stored value over positions ``low..high``.

        Both bounds are inclusive. Returns 0 when ``low > high`` or when
        nothing has been stored in the range.
        """
        result = 0
        # ``.get`` keeps this read-only: indexing the defaultdict directly
        # would insert a zero entry for every untouched node that is read.
        read = self.T.get
        low, high = low + self.Q, high + self.Q
        while low <= high:
            if low % 2 == 1:
                # ``low`` is a right child: take it, then step past it.
                result = max(result, read(low, 0))
                low += 1
            if high % 2 == 0:
                # ``high`` is a left child: take it, then step before it.
                result = max(result, read(high, 0))
                high -= 1
            low //= 2
            high //= 2
        return result

    def update(self, index, val):
        """Raise the value at position ``index`` to at least ``val``.

        The leaf keeps the larger of its old value and ``val``; every
        ancestor is then recomputed from its two children.
        """
        read = self.T.get
        index += self.Q
        self.T[index] = max(read(index, 0), val)
        index //= 2
        while index >= 1:
            self.T[index] = max(read(2 * index, 0), read(2 * index + 1, 0))
            index //= 2

class Solution:
    def solve(self, nums, k):
        """Return the best subsequence length found with the iterative tree.

        Values are shifted so the smallest becomes 1 and the largest becomes
        ``span + 1``, where ``span = max(nums) - min(nums)``. For every
        element, the tree is asked for the largest length stored at keys
        within ``k`` of it; that length plus one is the element's own length,
        which is then written back at its key.

        Returns 0 for empty input, otherwise the largest length computed.
        """
        if not nums:
            return 0

        base = min(nums)
        span = max(nums) - base
        top_key = span + 1
        tree = SegmentTree(2 * span)

        longest = 0
        for raw in nums:
            key = raw - base + 1
            window_lo = max(1, key - k)
            window_hi = min(key + k, top_key)
            length = tree.query(window_lo, window_hi) + 1
            tree.update(key, length)
            if length > longest:
                longest = length

        return longest