Problem statement

Solution 1

After sorting the pairs by A (ties broken by R), we need to find the maximum-sum non-decreasing subsequence of the R values.


Each of the n states does O(n) work, so this takes O(n^2) time and O(n) space, provided every state is evaluated only once.


class Solution:
    """Quadratic dynamic programme for the maximum-sum non-decreasing subsequence."""

    def solve(self, R, A):
        """Return the largest sum of a non-decreasing subsequence of ``R``.

        The pairs ``(A[i], R[i])`` are sorted by ``A`` (ties broken by ``R``)
        and the subsequence is taken over the ``R`` values in that order.

        Args:
            R: Values to be summed.
            A: Sort keys, paired with ``R`` position by position.

        Returns:
            The maximum subsequence sum, or 0 when the input is empty.
        """
        arr = [r for _, r in sorted(zip(A, R))]

        # best[k] is the largest sum of a non-decreasing subsequence that ends
        # at arr[k]. Filling the table bottom-up evaluates every state exactly
        # once (O(n^2) overall); an unmemoized recursion would re-solve each
        # prefix over and over and take exponential time.
        best = []
        for cur in arr:
            top = cur
            # best holds one entry per earlier element, so zip stops right
            # before the current position.
            for prev, prev_sum in zip(arr, best):
                if prev <= cur:
                    top = max(top, prev_sum + cur)
            best.append(top)

        # default=0 keeps empty input from raising ValueError.
        return max(best, default=0)

Solution 2

There is a better solution using a binary indexed tree (BIT). The idea is to use a BIT that answers prefix-maximum queries and to add the elements one by one in sorted order, from smallest to largest.


Time complexity is O(n log n) and space complexity is O(n).


class BIT:
    """Binary indexed tree over 1-based positions answering prefix-maximum queries.

    Every slot starts at 0, so a query never returns less than 0.
    """

    def __init__(self, n):
        """Create a tree covering positions ``1..n``."""
        # Slot 0 is unused; it only keeps the indexing 1-based.
        self.maxs = [0] * (n + 1)

    def update(self, i, delta):
        """Raise the stored maximum at position ``i`` to at least ``delta``.

        Args:
            i: 1-based position. Positions past the end are silently ignored.
            delta: Candidate value; despite the name it is max-merged, not added.

        Raises:
            IndexError: If ``i`` is less than 1.
        """
        if i < 1:
            # The walk below advances by the lowest set bit, which is 0 for
            # i == 0 (and negative indices climb up to 0), so it would spin
            # forever instead of terminating.
            raise IndexError("BIT positions are 1-based")
        size = len(self.maxs)
        while i < size:
            self.maxs[i] = max(self.maxs[i], delta)
            i += i & (-i)

    def query(self, i):
        """Return the maximum over positions ``1..i`` (0 when ``i`` < 1).

        Raises:
            IndexError: If ``i`` is larger than the tree size.
        """
        res = 0
        while i > 0:
            res = max(res, self.maxs[i])
            i -= i & (-i)
        return res
class Solution:
    """O(n log n) variant built on a prefix-maximum binary indexed tree."""

    def solve(self, R, A):
        """Return the largest sum of a non-decreasing subsequence of ``R``.

        The pairs ``(A[i], R[i])`` are sorted by ``A`` (ties broken by ``R``)
        and the ``R`` values are then processed in that order.

        Args:
            R: Values to be summed.
            A: Sort keys, paired with ``R`` position by position.

        Returns:
            The best sum found; the tree starts at 0, so the result is never
            below 0 and is 0 for empty input.
        """
        values = [r for _, r in sorted(zip(A, R))]

        # Coordinate compression: map each distinct value to a 1-based rank so
        # it can serve as a position in the tree.
        rank = {v: pos for pos, v in enumerate(sorted(set(values)), start=1)}
        size = len(rank)

        tree = BIT(size)
        for v in values:
            slot = rank[v]
            # Best sum ending in a value <= v seen so far, extended by v.
            tree.update(slot, tree.query(slot) + v)
        return tree.query(size)