Getting To Know Scala: Project Euler Primes

Prime numbers are not my thing, but generating them is a common task in the early Project Euler problems. The one algorithm I know for generating primes is the Sieve of Eratosthenes, which I defined in Scala as:

[sourcecode lang="scala"]
def successor(n : Double) : Stream[Double] = Stream.cons(n, successor(n + 1))

def sieve(nums : Stream[Double]) : Stream[Double] = Stream.cons(nums.head, sieve ((nums tail) filter (x => x % nums.head != 0)) )

val prime_stream = sieve(successor(2))
[/sourcecode]

The first function is the only function I've ever written that I'm sure is properly "functional." It's been stuck in my head since circa-1982 LISP. It uses Scala's Stream class, which is like a List but is "lazily evaluated": it only calculates the next value when it's needed. (The Stream pattern is to create a list whose head is the current value and whose tail is a recursive call that, when executed, will produce the next value.)
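To make the laziness visible, here is a minimal sketch. It assumes a Scala version where Stream is still available (in Scala 2.13 and later it is deprecated in favor of LazyList, which behaves similarly). The LazyDemo object and the computed counter are illustrative names, not part of the code above, and Long is used instead of Double only to keep the printed output tidy.

```scala
object LazyDemo {
  // Counts how many stream cells have actually been constructed.
  var computed = 0

  // Same shape as successor above, plus a counter.
  def successor(n: Long): Stream[Long] = {
    computed += 1
    // The second argument of Stream.cons is passed by name,
    // so the recursive call does not run until the tail is requested.
    Stream.cons(n, successor(n + 1))
  }

  // Asking for five elements forces only a handful of cells;
  // the "infinite" recursion never runs away.
  val firstFive: List[Long] = successor(2).take(5).toList

  def main(args: Array[String]): Unit = {
    println(firstFive) // List(2, 3, 4, 5, 6)
    println("cells constructed: " + computed)
  }
}
```

The stream is conceptually infinite, yet only the cells that take(5) demands are ever built.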

The second function, sieve, is my take on the Sieve of Eratosthenes. It too returns a Stream, this time of primes. (By the way, the reason I use Double rather than an Int or Long is that one of the early Project Euler problems involves a prime larger than LONG_MAX.)

In case you're not familiar with the algorithm, the Sieve is conceptually simple. Begin with a list containing all positive integers starting at 2 (the first prime) [2, 3, 4, ...] . Remove from the list every multiple of your current prime. The first number remaining is the next prime. For instance, after removing [2, 4, 6, ... ], the first number remaining is 3. Prime! So remove [3, 6 (already removed), 9, ... ]. Since 4 was removed as a multiple of 2, the next available is 5. Prime! Remove [5, 10 (already removed), 15 (already removed), ...] ...
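The walk-through above can be sketched on a finite list, which makes each "remove the multiples" step explicit. This is only an illustration under the assumption of a fixed upper limit. SieveDemo, sieve(limit) and loop are hypothetical names, not the stream-based definitions from earlier.

```scala
object SieveDemo {
  // Returns all primes <= limit by repeatedly taking the first remaining
  // number as a prime and filtering its multiples out of the rest.
  def sieve(limit: Int): List[Int] = {
    @annotation.tailrec
    def loop(candidates: List[Int], primes: List[Int]): List[Int] =
      candidates match {
        case Nil       => primes.reverse
        case p :: rest => loop(rest.filter(_ % p != 0), p :: primes)
      }
    loop((2 to limit).toList, Nil)
  }

  def main(args: Array[String]): Unit = {
    println(sieve(30)) // List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
  }
}
```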

The 7th Project Euler problem is "What is the 10001st prime number?" Unfortunately,

[sourcecode]
scala> prime_stream take 10001 print
2.0, 3.0, 5.0, 7.0, ...SNIP ... 29059.0, 29063.0, java.lang.OutOfMemoryError: Java heap space
at scala.Stream$cons$.apply(Stream.scala:62)
at scala.Stream.filter(Stream.scala:381)
at scala.Stream$$anonfun$filter$1.apply(Stream.scala:381)
at scala.Stream$$anonfun$filter$1.apply(Stream.scala:381)
at scala.Stream$cons$$anon$2.tail(Stream.scala:69)
at scala.Stream$$anonfun$filter$1.apply(Stream.scala:381)
at scala.Stream$$anonfun$...
[/sourcecode]

That will never do. Obviously, I could run Scala with more heap space, but that would only be a bandage. Since a quick Google search shows that the 10,000th prime number is 104,729 and I'm running out of heap space near 30K, it seems that "messing around with primes near the 10Kth mark" requires some memory optimization.

Converting the Sieve

If I really wanted to work with very large sequences of primes, I should certainly move away from the Sieve of Eratosthenes. But I'm not really interested in prime number algorithms; I'm interested in the characteristics of the Scala programming language, so I'm going to intentionally ignore better algorithms.

My first thought was "OK, I'll allocate a chunk of memory and every time I find a prime, I'll set every justFoundPrime-th bit to 1." But that would depend upon my allocated memory being sufficient to hold the nth prime. With my Google-powered knowledge that the 10001st prime is only 100K or so, that would be easy enough, but (a) it seemed like cheating and (b) it would require a magic number in my code.
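For comparison, that first idea might look roughly like the sketch below. It is not what I ended up writing. ArraySieve and primesUpTo are hypothetical names, and the limit passed in is exactly the kind of magic number described above.

```scala
object ArraySieve {
  // Classic fixed-size sieve: composite(i) is set to true once i is known
  // to be a multiple of an earlier prime.
  def primesUpTo(limit: Int): Vector[Int] = {
    val composite = new Array[Boolean](math.max(limit, 0) + 1)
    val found = Vector.newBuilder[Int]
    var n = 2
    while (n <= limit) {
      if (!composite(n)) {
        found += n
        // Mark every n-th slot after n; Long avoids overflow near Int.MaxValue.
        var m = 2L * n
        while (m <= limit) {
          composite(m.toInt) = true
          m += n
        }
      }
      n += 1
    }
    found.result()
  }

  def main(args: Array[String]): Unit = {
    // 105000 is an assumed bound that happens to be large enough here.
    val primes = primesUpTo(105000)
    println(primes(10000)) // index 10000 is the 10001st prime
  }
}
```

The weakness is visible in main: the bound has to be known, or guessed, before the sieve starts.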

My next thought was "OK, when I run out of space, I'll dynamically allocate double the space -- no, wait, I only need space for the range from justFoundPrime to 2 * justFoundPrime, since I've already checked the numbers up to justFoundPrime."

My next thought was "And really I only need half that space, since I know 2 is a prime and I can just count by 2... And, y'know, I know 3 is prime too, so I can check--" At which point, I engaged in a mental battle over what was appropriate algorithmic behavior.

On the one hand, I didn't want to change algorithms: if I moved from the Sieve to a slightly better algorithm, then wasn't it Shameful not to move to at least a Quite Good algorithm? On the other hand, the instant I opened the door to allocating new memory, I committed to keeping around the list of already-discovered primes, since I would have to apply that list to my newly-allocated memory. But if you have a list of numbers, checking if your candidate number is a multiple of any of them can be done without consuming any additional memory. But is it the same algorithm? Isn't the Sieve fundamentally about marking spots in a big array?

Finally, I decided that checking a candidate number against the list of already-discovered primes was the Sieve algorithm, just with the smallest possible amount of memory allocation -- one number. (By the way, did you read the article in which scientists say that rational thought is just a tool for winning arguments to which you're already emotionally committed?)

Here then, is what I wrote:

[sourcecode lang="scala"]
def multiple_of = (base : Long, target : Long) => target % base == 0;

val twoFilter = multiple_of(2, _ : Long)
val threeFilter = multiple_of(3, _ : Long)

println(twoFilter(4))
println(twoFilter(5))
println(twoFilter(6))
println(threeFilter(4))
println(threeFilter(5))
println(threeFilter(6))
[/sourcecode]

The first definition, multiple_of, wraps a function literal that returns true if the target is a multiple of the base. The next two lines, where I define twoFilter and threeFilter, are an example of the functional idiom of partial function application. (As far as I can tell, "currying" is the closely related idea of turning a function of several arguments into a chain of one-argument functions.)
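A small sketch of the difference between the two ideas, using hypothetical names (CurryDemo, multipleOf, curried) rather than the definitions above. Function2.curried is a standard library method in Scala 2.8 and later.

```scala
object CurryDemo {
  // A two-argument function value, mirroring multiple_of above.
  val multipleOf: (Long, Long) => Boolean =
    (base, target) => target % base == 0

  // Partial application: fix the first argument, leave a hole for the second.
  val twoFilter: Long => Boolean = multipleOf(2, _: Long)

  // Currying: turn the two-argument function into a chain of
  // one-argument functions.
  val curried: Long => Long => Boolean = multipleOf.curried

  // Supplying just the first argument of the curried form yields a filter too.
  val threeFilter: Long => Boolean = curried(3)

  def main(args: Array[String]): Unit = {
    println(twoFilter(6))   // true
    println(threeFilter(4)) // false
  }
}
```

Either route ends with a one-argument function that remembers its base.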

This is an undeniably cool feature of functional languages. Without any fuss, these lines create new functions that take one fewer argument because they carry the context they need with them. Once you have a twoFilter, you don't need to keep the value "2" around to pass in. That might not seem like a big win, since a function named twoFilter or threeFilter is no more compact than calling multiple_of(2, x) or multiple_of(3, x). But...

[sourcecode lang="scala" firstline="12"]
def filter = (x : Long) => multiple_of(x, _ : Long);

val fs = List(filter(2), filter(3))
for(f <- fs){
  println("Eight is a multiple of this filter: " + f(8))
}
[/sourcecode]

OK, now that's nice and compact! Now we have a new function called filter, and rather than have a bunch of variables called twoFilter and threeFilter and fiveFilter, we just have a List of filters. With such a list in hand, it's easy to figure out which numbers in a list are relatively prime to every base in the filter list:

[sourcecode lang="scala" firstline="18"]
def relatively_prime(fs : List[(Long)=>Boolean], target : Long) : Boolean = {
  for(f <- fs){
    if(f(target)){
      return false;
    }
  }
  return true;
}

println("4 is prime? " + relatively_prime(fs, 4))
println("5 is prime? " + relatively_prime(fs, 5))

val list = List[Long](2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
println(list.map(relatively_prime(fs, _)))

[/sourcecode]
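The loop-and-return version can also be written with the collection methods on List. A minimal sketch, with hypothetical camelCase names (RelativelyPrimeDemo, multipleOf, relativelyPrime) so it doesn't clash with the definitions above:

```scala
object RelativelyPrimeDemo {
  def multipleOf(base: Long, target: Long): Boolean = target % base == 0

  // Builds a one-argument filter that remembers its base.
  def filter(base: Long): Long => Boolean = multipleOf(base, _: Long)

  // True when no filter in fs claims target as a multiple.
  def relativelyPrime(fs: List[Long => Boolean], target: Long): Boolean =
    !fs.exists(f => f(target))

  val fs: List[Long => Boolean] = List(filter(2), filter(3))

  // Keep only the numbers that survive every filter.
  val survivors: List[Long] =
    List[Long](2, 3, 4, 5, 6, 7, 8, 9, 10, 11).filter(relativelyPrime(fs, _))

  def main(args: Array[String]): Unit = {
    println(survivors) // List(5, 7, 11)
  }
}
```

Note that 2 and 3 themselves are filtered out, since each is a multiple of its own base.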

Which leads to a simple recursive function to find the next prime:

[sourcecode lang="scala" firstline="32"]
def next_prime(fs : List[(Long)=>Boolean], x : Long) : Long = {
  if (relatively_prime(fs, x)) {
    return x
  }
  return next_prime(fs, x + 1)
}

println(next_prime(fs, 4))
println(next_prime(fs, 8))
[/sourcecode]

Which leads to our solution:

[sourcecode lang="scala" firstline="41"]
def primes(fs : List[(Long)=>Boolean], ps: List[Long], x : Long, nth : Long) : List[Long] = {
  if(ps.size == nth){
    return ps;
  }

  val np = next_prime(fs, x)
  val sieve = fs ::: List(filter(np));
  primes(sieve, ps ::: List(np), np + 1, nth);
}

println("Missing 2 and 3 because they're in fs: " + primes(fs, List[Long](), 2, 8))
println((primes(List(filter(2)), List(2L), 2, 8) reverse) head)

def nth_prime(nth : Long) : Long = {
  (
    primes(
      List(filter(2)),
      List(2L),
      2,
      nth
    ) reverse
  ) head
}

println(nth_prime(8))
println("The 10001st prime is " + nth_prime(10001))
[/sourcecode]
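For what it's worth, the same "check each candidate against the primes found so far" idea can be sketched without a list of closures, by keeping the primes themselves in a growable buffer. This is an alternative under my own assumptions, not a replacement for the code above. TrialDivision and nthPrime are hypothetical names.

```scala
import scala.collection.mutable.ArrayBuffer

object TrialDivision {
  // Returns the n-th prime (1-based) by testing each candidate against
  // every prime discovered so far.
  def nthPrime(n: Int): Long = {
    require(n >= 1, "n must be at least 1")
    val found = ArrayBuffer.empty[Long]
    var candidate = 2L
    while (found.length < n) {
      // A candidate is prime if no earlier prime divides it.
      if (found.forall(p => candidate % p != 0)) found += candidate
      candidate += 1
    }
    found.last
  }

  def main(args: Array[String]): Unit = {
    println(nthPrime(8)) // 19
    println("The 10001st prime is " + nthPrime(10001))
  }
}
```

Appending to an ArrayBuffer and reading its length are cheap, whereas ps ::: List(np) and ps.size on an immutable List each walk the whole list on every call.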