Project Euler Problem 303 - Solved

Solution to Project Euler's problem 303 in Python using a straightforward approach.

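The statement for this problem can be found on the Project Euler website. In short: for a positive integer n, f(n) is the least positive multiple of n whose decimal digits are all at most 2, and the task is to find the sum of f(n)/n for 1 ≤ n ≤ 10000.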
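A brute-force version of f(n), which tries n, 2n, 3n, ... until every digit is at most 2, is far too slow for the full range but handy for checking a faster solver on small n. This is only a reference sketch; `f_brute` is a hypothetical helper name, not part of the original solution.

```python
def f_brute(n):
    """Least positive multiple of n whose decimal digits are all 0, 1 or 2.

    Tries n, 2n, 3n, ... in order, so it is only practical for small n.
    """
    multiple = n
    while max(str(multiple)) > '2':   # some digit is 3..9
        multiple += n
    return multiple


for n in (2, 3, 7, 42, 89):
    print(n, f_brute(n))
```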

In order to solve this problem in a reasonable time you should:

  • Never repeat cases that will yield the same results.
  • The first solution that fits is the correct one.
  • Efficiently advance through the quest.

Let's take each point separately.

Not repeating cases

To achieve this, it is important to notice that when you multiply two numbers, the leftmost digits of the result are affected (apart from a possible carry) only by the n leftmost digits of the second factor, where n is the number of digits of the first factor plus 1.

So we collect those digits and, before processing a candidate multiplier, check that its n leftmost digits have not been processed before.
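A minimal sketch of that bookkeeping, assuming the multiplier is kept as a list of digits with the most significant digit first (as the final program does). `first_visit` is a hypothetical helper name, not part of the original code.

```python
def first_visit(digits, number, seen):
    """Return True the first time the leading digits of a candidate are seen.

    digits: candidate multiplier, most significant digit first.
    number: the number whose multiple we are looking for.
    seen:   set of prefixes already processed (updated in place).
    """
    width = len(str(number)) + 1   # n = digits of the number + 1
    key = tuple(digits[:width])    # the n leftmost digits of the multiplier
    if key in seen:
        return False
    seen.add(key)
    return True


seen = set()
print(first_visit([1, 2, 3, 4], 7, seen))  # True: prefix (1, 2) is new
print(first_visit([1, 2, 9], 7, seen))     # False: prefix (1, 2) already processed
print(first_visit([2, 1], 7, seen))        # True: prefix (2, 1) is new
```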

The first solution that fits is the correct one

This requires the candidate answers to be processed in increasing order, regardless of the order in which they are generated.

I made this possible by storing the candidates in a min-heap, ordered first by the length of the multiplier and then by its digits.
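Python's `heapq` compares tuples element by element, so a key of (number of digits, list of digits) pops shorter multipliers first and, among multipliers of the same length, compares the digit lists lexicographically. With the most significant digit first, that is the same as numeric order. (The final program stores the length plus one as the first element, which orders the same way.) A short illustration:

```python
import heapq

# Multipliers stored as (number of digits, digits most-significant first).
candidates = [[2, 1], [1, 0, 0], [3], [1, 2], [9, 9]]
heap = [(len(digits), digits) for digits in candidates]
heapq.heapify(heap)

order = []
while heap:
    _, digits = heapq.heappop(heap)
    order.append(int(''.join(map(str, digits))))

print(order)  # [3, 12, 21, 99, 100]
```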

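In each iteration I pop the smallest element and check whether it is the solution. If it is not, I push every candidate that can be derived from that multiplier by adding one more digit on its left, keeping only those where the next digit of the product is still 0, 1 or 2.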
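The pop/check/push loop can be sketched on its own. This simplified version is written for illustration only: it keeps the min-heap ordering but leaves out the prefix deduplication and the stored product, recomputing n * m on every step, so it is meant for small n. It relies on the fact that the product digit at position k depends only on the k + 1 lowest digits of the multiplier, so once a digit is prepended at that position the product digit there is final and can be used for pruning. `smallest_multiplier` is a hypothetical name.

```python
import heapq


def smallest_multiplier(n):
    """Smallest m >= 1 such that n * m uses only the digits 0, 1 and 2."""
    # Heap entries: (number of digits, digits with the most significant first).
    # Seed with every last digit that makes the product end in 0, 1 or 2.
    heap = [(1, [d]) for d in range(10) if (n * d) % 10 <= 2]
    heapq.heapify(heap)
    while True:
        length, digits = heapq.heappop(heap)
        m = int(''.join(map(str, digits)))
        # A leading zero means this value was already checked with fewer digits.
        if digits[0] != 0 and max(str(n * m)) <= '2':
            return m
        # Prepend a digit at position `length`. The product digit at that
        # position can no longer change, so drop candidates where it exceeds 2.
        for d in range(10):
            new_m = d * 10 ** length + m
            if (n * new_m // 10 ** length) % 10 <= 2:
                heapq.heappush(heap, (length + 1, [d] + digits))


print([smallest_multiplier(n) for n in (2, 3, 7, 42, 89)])
print(sum(smallest_multiplier(n) for n in range(1, 101)))
```

Because every suffix of a valid multiplier passes the pruning test, the first valid candidate popped is the smallest one.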

Efficiently advance through the quest

Every candidate is derived from a previously processed one, and redoing the full multiplication for each candidate is expensive. To save time, the tuple pushed onto the min-heap also stores the product of the number and that multiplier.

When a candidate is processed and it is not the answer, instead of recomputing the multiplication for each derived candidate I just add new digit * number being processed * 10 ** (number of digits of the current multiplier) to the stored product, which is much cheaper.
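In code, the update is one multiplication by a power of ten and one addition. A minimal sketch with a hypothetical helper `prepend_digit`, checked against direct multiplication:

```python
def prepend_digit(n, digits, product, d):
    """Prepend digit d to the multiplier and update product = n * multiplier.

    digits:  current multiplier, most significant digit first.
    product: n * int(''.join(map(str, digits))), already known.
    """
    k = len(digits)                     # the new digit is worth d * 10**k
    return [d] + digits, product + d * n * 10 ** k


n, digits, product = 7, [2, 1], 7 * 21  # multiplier 21, product 147
digits, product = prepend_digit(n, digits, product, 3)
print(digits, product)                  # [3, 2, 1] 2247
print(product == n * 321)               # True
```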

Final Solution

I wrote a Python program that solves the problem; the file is problem303.py:

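```python
import heapq

LIMIT = 10000

# mult_digit[a][b] lists every digit j such that a * j ends in b.
# Only the last digit of the number decides the last digit of a product,
# so this table tells which multiplier digit yields a wanted product digit.
mult_digit = {}
for i in range(10):
    mult_digit[i] = {}
    for j in range(10):
        mult_digit[i].setdefault((i * j) % 10, []).append(j)


def find_mult(i):
    """Return the smallest multiplier m such that i * m only uses digits 0-2."""
    str_i = str(i)
    last_digit = i % 10
    # Heap entries: (number of digits of the multiplier + 1,
    #                multiplier digits, most significant first,
    #                i * multiplier).
    # Seed with the one-digit multipliers that make the product end in 0, 1 or 2.
    heap = []
    for r in range(3):
        heap.extend((2, [j], i * j) for j in mult_digit[last_digit].get(r, []))
    heapq.heapify(heap)
    used_set = set()
    while True:
        next_affected, list_n, parcial_mult = heapq.heappop(heap)

        # Candidates pop in increasing order, so the first valid one is the answer.
        # A leading 0 means the same value was already tested with fewer digits.
        if list_n[0] != 0 and max(str(parcial_mult)) < '3':
            return int(''.join(map(str, list_n)))

        # Do not expand a candidate whose len(str_i) + 1 leftmost digits
        # were already expanded.
        prefix = tuple(list_n[:len(str_i) + 1])
        if prefix in used_set:
            continue
        used_set.add(prefix)

        # s is the product digit at the position where the next multiplier
        # digit lands (0 if the product is still shorter than that).
        str_parcial_mult = str(parcial_mult)
        try:
            s = int(str_parcial_mult[-next_affected])
        except IndexError:
            s = 0
        for d in range(3):
            # We need (s + mult * last_digit) % 10 == d, for d in 0, 1, 2.
            target = (d - s) % 10
            for mult in mult_digit[last_digit].get(target, []):
                new_list_n = [mult] + list_n
                # Incremental update: the new digit is worth
                # mult * 10 ** (next_affected - 1) in the multiplier.
                heapq.heappush(heap, (next_affected + 1, new_list_n,
                                      parcial_mult + mult * i * (10 ** (next_affected - 1))))


if __name__ == '__main__':
    result = sum(find_mult(i) for i in range(1, LIMIT + 1))
    print("The result is:", result)
```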
