Problem statement

Solution 1 (text by awice)

Let P[i] be the prefix sum sum(A[:i]), and let the power of A be the sum of i * A[i-1] over 1-based positions i (this is base in the code below). When we move an element A[i] = x (with i a 0-based index) to some position j in 0..n, the power changes by P[i] - P[j] - (i - j) * x, which we can rewrite as (-i * x + P[i]) - (-j * x + P[j]).
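The formula can be sanity-checked against an explicit move. The sketch below is only an illustration: the helper names power, delta_formula and delta_by_moving are made up here, and it takes i as a 0-based index and j in 0..len(A), the convention under which the formula matches a direct recomputation.

```python
from itertools import accumulate


def power(A):
    # Sum of position * value, with 1-based positions.
    return sum(i * x for i, x in enumerate(A, 1))


def delta_formula(A, i, j):
    # Change in power predicted by P[i] - P[j] - (i - j) * x.
    P = [0] + list(accumulate(A))
    x = A[i]
    return P[i] - P[j] - (i - j) * x


def delta_by_moving(A, i, j):
    # Actually move A[i] and recompute the power from scratch.
    x = A[i]
    B = A[:i] + A[i + 1:]
    # For j <= i the element lands at 0-based index j, otherwise at j - 1.
    B.insert(j if j <= i else j - 1, x)
    return power(B) - power(A)
```

For A = [3, 2, 1], moving A[0] = 3 to the end (j = 3) changes the power by 3 (from 10 to 13) under both functions.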

Now, let’s say the j-th line is eval_at(j, x) = -j * x + P[j]. For each pair (i, x), we want the minimum possible eval_at(j, x) over all lines j, since that maximizes the change in power. We consider the lines in decreasing order of slope; that way, whenever two lines intersect at x_int = intersection(j1, j2) [with j1 < j2], eval_at(j2, x) is the lower value for x >= x_int.

At the end, we have a list hull of increasing x-intersections and the associated list of line indices, indexes. For each x, we binary-search hull to find the corresponding line indexes[...] that yields the lowest value of eval_at(j, x).
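To make the hull and indexes lists concrete, here is a small sketch of the construction and the query in isolation. The function names build_lower_envelope and min_over_lines are made up for this example, and the -inf sentinel for the first breakpoint is a choice made for this sketch.

```python
import bisect
from itertools import accumulate


def build_lower_envelope(P):
    # Lines are y = -j * x + P[j], visited in decreasing order of slope.
    def intersection(j1, j2):
        return (P[j2] - P[j1]) / (j2 - j1)

    hull = [float("-inf")]  # hull[k]: x from which line indexes[k] is lowest
    indexes = [0]
    for j in range(1, len(P)):
        # Pop lines that line j undercuts before they ever become lowest.
        while len(hull) > 1 and intersection(indexes[-1], j) <= hull[-1]:
            hull.pop()
            indexes.pop()
        hull.append(intersection(indexes[-1], j))
        indexes.append(j)
    return hull, indexes


def min_over_lines(P, hull, indexes, x):
    # The last breakpoint <= x identifies the lowest line at x.
    j = indexes[bisect.bisect(hull, x) - 1]
    return -j * x + P[j]


P = [0] + list(accumulate([3, 2, 1]))
hull, indexes = build_lower_envelope(P)
print(hull, indexes)  # [-inf, 2.0] [0, 3]
```

For A = [3, 2, 1] only lines 0 and 3 survive: line 0 is lowest for x < 2 and line 3 from x = 2 onward, which agrees with taking the minimum over all four lines directly.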

For more information, google search “convex hull trick”.

See also similar Leetcode Problem 0644 Maximum Average Subarray II, where one of the possible solutions uses the same trick.


Building the hull takes O(n) time and space. Each of the n queries is a binary search, so the total time is O(n log n).


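import bisect
from itertools import accumulate


class Solution:
    def solve(self, A):
        # P[j] = sum(A[:j]); base is the power of the unmodified array.
        P = [0] + list(accumulate(A))
        base = sum(i * x for i, x in enumerate(A, 1))

        def eval_at(j, x):
            # Value at x of the j-th line (slope -j, intercept P[j]).
            return -j * x + P[j]

        def intersection(j1, j2):
            # x-coordinate where lines j1 and j2 cross.
            return (P[j2] - P[j1]) / (j2 - j1)

        # hull[k] is the x from which line indexes[k] is the lowest one.
        # The -inf sentinel keeps line 0 in place below the first breakpoint.
        hull = [float("-inf")]
        indexes = [0]
        for j in range(1, len(P)):
            # Drop lines that line j undercuts before they become lowest.
            while hull and intersection(indexes[-1], j) <= hull[-1]:
                hull.pop()
                indexes.pop()
            hull.append(intersection(indexes[-1], j))
            indexes.append(j)

        ans = base
        for i, x in enumerate(A):
            # The last breakpoint <= x gives the line minimizing eval_at(j, x).
            j = bisect.bisect(hull, x)
            j = max(j - 1, 0)
            ans = max(ans, base + eval_at(i, x) - eval_at(indexes[j], x))
        return ans % (10 ** 9 + 7)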

Solution 2

Another solution again uses the eval_at function, but now we use ternary search to find the optimum of a convex function. Notice that the function

\(f(x) = \max\limits_{i=1..n} \texttt{eval_at}(i, x)\) is convex. Why the rest of the solution works, I am not sure; something about convex duality, maybe? It is also possible that it is wrong: the code searches over j for a fixed x, and consecutive values of eval_at(j, x) differ by A[j] - x, so eval_at(j, x) need not be unimodal in j.
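One way to probe that doubt is to tabulate eval_at(j, x) over j for a fixed x and look at its shape. The sketch below does this for a made-up array; the helper name eval_over_j and the input are illustrative only.

```python
from itertools import accumulate


def eval_over_j(A, x):
    # eval_at(j, x) = -j * x + P[j] for every j in 0..len(A).
    P = [0] + list(accumulate(A))
    return [-j * x + P[j] for j in range(len(P))]


values = eval_over_j([0, 100, 4, 4, 5, 5], 5)
print(values)  # [0, -5, 90, 89, 88, 88, 88]
```

The values drop, rise, then drop again, so for this x the sequence is not unimodal in j. A ternary search over [0, 6] that first compares j = 2 with j = 4 discards j = 1, the true minimum. This only shows that the search for a single x can miss; it does not by itself show that the final answer is wrong for this array.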


It is O(n log n).


from itertools import accumulate


class Solution:
    def solve(self, A):
        base = ans = sum(i * x for i, x in enumerate(A, 1))
        n = len(A)
        P = [0] + list(accumulate(A))

        def eval_at(j, x):
            return -j * x + P[j]

        for i, x in enumerate(A):
            # Ternary search over j in [l, r] for the minimum of eval_at(j, x).
            l, r = 0, n
            while r - l >= 4:
                m1, m2 = l + (r - l) // 3, r - (r - l) // 3
                if eval_at(m1, x) < eval_at(m2, x):
                    r = m2
                else:
                    l = m1

            # Check the few remaining candidates directly.
            for j in range(l, r + 1):
                ans = max(ans, base + eval_at(i, x) - eval_at(j, x))

        return ans % (10**9 + 7)
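Since the correctness of this approach is in question, a brute-force reference is handy for testing either solution on small inputs. This is a minimal sketch with made-up helper names power and brute_force. It assumes the task is to maximize the power after moving at most one element, and it leaves out the final modulo.

```python
def power(A):
    # Sum of position * value, with 1-based positions.
    return sum(i * x for i, x in enumerate(A, 1))


def brute_force(A):
    # Try removing every element and reinserting it at every position.
    best = power(A)
    for i in range(len(A)):
        rest = A[:i] + A[i + 1:]
        for j in range(len(A)):
            best = max(best, power(rest[:j] + [A[i]] + rest[j:]))
    return best
```

For example, brute_force([3, 2, 1]) returns 13, reached by moving the 3 to the end. To compare with solve, reduce the reference value modulo 10**9 + 7 first; the reference runs in cubic time, so it is only practical for short arrays.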