February 2021 - Branch Wars
Created on February 4, 2021

Introduction
Congratulations, Ryan Howard! You've just become the youngest executive in the history of Dunder Mifflin. No one knows how you pulled it off, but regardless, congrats. As your first executive-level task, you've been assigned to figure out the profit-per-employee of different branches and decide which one should be shut down. Being the super-genius-prodigy that you are, you decide to write a purely functional computer program that will compute the answers for you. No wonder you're an executive!
For this challenge you will be given a list of branches. From there you will determine the profit per employee of each branch and output the name of the branch with the lowest profit per employee so it can be shut down.
The model for a branch looks like this:
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
final case class BranchName(value: String) extends AnyVal
final case class Branch(name: BranchName, manager: Employee.Manager)
Each branch has a single manager, and each manager has zero or more direct reports. Each of those employees may be either a manager or an individual contributor. For examples of what branches look like, see the ChallengeSpec in the source code.
Getting Started
- Clone the starter code here. The starter code contains acceptance tests. Get the ChallengeSpec to pass and you are done!
- Run the tests to ensure you are set up (using sbt test, for example). All tests will be failing for now.
Note that the starter code contains two main elements: the challenge and the fundamentals. The challenge is the problem described above, and the fundamentals are simpler, shorter problems you can solve to brush up on your knowledge before tackling the full challenge. Feel free to skip the fundamentals if you don't need them. Below you will find a walk-through of how to solve all the fundamentals as well as the main challenge.
Tip: As you work through the problems, it may be helpful to run only the tests for either the fundamentals or the challenge. For example, to run only the fundamentals tests, use sbt 'testOnly **.FundamentalsSpec'.
Fundamentals
Everything after this point is going through solutions to the fundamentals and the main challenge. Stop here if you wish to solve them on your own rather than following this walk-through!
If you want to follow along with the code as we go, you can find the solution repo here.
Fold All the Things
If you've been using Scala for a while, you may already be familiar with the fold, foldLeft, or foldRight operations. These operations are extremely powerful and more widely applicable than you might expect. When I first started using Scala, I didn't know you could use fold on structures other than Lists, but you can! The goal of these fundamentals is to introduce you to using fold operations to work with Lists and other structures as well.
Fundamental One
Sum the list of integers l.
def one(l: List[Int]): Int = l.fold(0)(_ + _)
This is probably the simplest and most canonical example of a fold operation. Here we are summing a list of numbers using the fold operation. The 0 represents the starting point of the fold: if the list l were empty, this fold would return 0. From there we are telling the fold operation to add each element to the running sum of the elements that came before it (starting from 0). For example, if we input List(1, 2, 3), then the following would happen:
- 0 + 1 = 1
- 1 + 2 = 3
- 3 + 3 = 6
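Note: fold does not guarantee the order in which these operations are applied, so they could actually be executed in a different order than the one listed above (this is why the function passed to fold should be associative). If we needed them to be done strictly in order, we would need to use foldLeft or foldRight.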
Also, don't let the (_ + _) syntax scare you. It is just shorthand for:
{ (sumOfPreviousNumbers, latestNumber) =>
  sumOfPreviousNumbers + latestNumber
}
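To watch the running total from the steps above, we can lean on scanLeft, a standard List method that works like foldLeft but keeps every intermediate accumulator. This is a minimal sketch. The names FoldTrace and steps are hypothetical. scanLeft always works left to right, so it shows one valid order for these additions, not an order that fold promises to use.

```scala
object FoldTrace {
  // Same fold as Fundamental One: start at 0 and add each element
  def one(l: List[Int]): Int = l.fold(0)(_ + _)

  // scanLeft keeps every intermediate accumulator of a left-to-right fold,
  // making the 0 + 1, 1 + 2, 3 + 3 steps visible (the first entry is the start value)
  def steps(l: List[Int]): List[Int] = l.scanLeft(0)(_ + _)
}

// FoldTrace.steps(List(1, 2, 3)) returns List(0, 1, 3, 6)
// FoldTrace.one(List(1, 2, 3))   returns 6
```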
Fundamental Two
Concatenate the list of Chars into a String.
def two(l: List[Char]): String = l.foldLeft("")(_ + _)
This is very similar to the last example, except this time we start with the empty String and then use + to append each Char as it comes. We are also using foldLeft instead of fold. We need to do this because fold can only be used when the type you are returning is the same as (or a supertype of) the type found inside your List. For example, List("hello", "world").fold("")(_ + _) works because the List contains Strings and the return type is also a String. You can see this in concrete terms by looking at the definitions of fold and foldLeft:
def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1
def foldLeft[B](z: B)(op: (B, A) => B): B
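The two signatures can be put side by side in code. This is a small sketch. FoldVsFoldLeft and joinWords are hypothetical names, and the comment about the rejected fold call follows from the A1 >: A bound shown above.

```scala
object FoldVsFoldLeft {
  // fold works here: the elements are Strings and the result is a String,
  // so A1 can simply be String
  def joinWords(l: List[String]): String = l.fold("")(_ + _)

  // foldLeft has its own result type B, independent of the element type A.
  // Here A = Char and B = String, which fold cannot express.
  def two(l: List[Char]): String = l.foldLeft("")(_ + _)

  // By contrast, l.fold("")(_ + _) on a List[Char] is rejected by the compiler:
  // A1 must be a supertype of both Char and String (so it widens to Any),
  // and Any has no + that accepts another Any.
}
```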
Fundamental Three
Stringify the optional input.
- If the option is None, return "None" as a string.
- If the option is defined, return the value it contains as a string wrapped in Some(...).
def three(l: Option[String]): String = l.fold("None")(s => s"Some($s)")
Of course, you could easily implement this function using a match statement, but this gives a good intuition for what a fold is doing. You provide what should be done in the None (empty) case and the Some (non-empty) case, and the fold function takes care of the rest.
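Here is the match version next to the fold, written out so the two can be compared directly. It is a minimal sketch, and the names OptionFold and threeMatch are hypothetical.

```scala
object OptionFold {
  // Fundamental Three as written: the first argument covers None,
  // the function covers Some
  def three(l: Option[String]): String = l.fold("None")(s => s"Some($s)")

  // The equivalent pattern match, spelling out what fold does for us
  def threeMatch(l: Option[String]): String = l match {
    case None    => "None"
    case Some(s) => s"Some($s)"
  }
}
```

Option's own fold takes its first argument by name (ifEmpty: => B), so "None" is only computed when the option is empty. The JobStatus fold in the next fundamental uses the same trick.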
Fundamental Four
Implement the fold operation for JobStatus.
import java.time.Instant
sealed abstract class JobStatus extends Product with Serializable {
  // Companion members are not in scope inside the class body, so the cases are qualified
  def fold[A](stopped: => A)(running: JobStatus.Running => A): A = this match {
    case r: JobStatus.Running => running(r)
    case JobStatus.Stopped => stopped
  }
}
object JobStatus {
  final case class Running(startedAt: Instant) extends JobStatus
  case object Stopped extends JobStatus
}
When implementing a fold operation for an algebraic data type (ADT) such as JobStatus, it is simplest to have one parameter in the fold for each possibility in the ADT. In this case, JobStatus can be either Stopped or Running. For this reason, we add (stopped: => A) for Stopped and (running: JobStatus.Running => A) for Running. Notice that the stopped parameter is passed by name, so the expression passed for it is only evaluated if the JobStatus is in fact Stopped.
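To see the by-name parameter in action, here is a self-contained sketch of the JobStatus fold with a hypothetical ByNameDemo object. It counts how often the stopped argument is actually evaluated. The cases are written as JobStatus.Running and JobStatus.Stopped because companion members are not automatically in scope inside the class body.

```scala
import java.time.Instant

sealed abstract class JobStatus extends Product with Serializable {
  // One parameter per case; `stopped` is by-name, so it is only evaluated for Stopped
  def fold[A](stopped: => A)(running: JobStatus.Running => A): A = this match {
    case r: JobStatus.Running => running(r)
    case JobStatus.Stopped    => stopped
  }
}

object JobStatus {
  final case class Running(startedAt: Instant) extends JobStatus
  case object Stopped extends JobStatus
}

object ByNameDemo {
  // Incremented every time the `stopped` argument is actually evaluated
  var evaluations = 0

  def stoppedText: String = { evaluations += 1; "Stopped" }

  def describe(s: JobStatus): String =
    s.fold(stoppedText)(r => s"Started at ${r.startedAt}")
}
```

Calling ByNameDemo.describe on a Running job leaves evaluations at 0. Only a Stopped job triggers stoppedText.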
Now let's use the fold operation we implemented to return "Stopped" for a JobStatus.Stopped job and "Started at $startedAt" for a JobStatus.Running job.
def four(l: JobStatus): String = l.fold("Stopped")(r => s"Started at ${r.startedAt}")
This implementation is very similar to the one we did for Option above. We provide what to do when a job is Stopped and what to do when a job is Running, and the fold takes care of the rest.
Fundamental Five
Return the length of the input list l.
def five[A](l: List[A]): Int = l.foldLeft(0)((a, _) => a + 1)
Here we are able to compute the length of the input list by starting with 0 and then adding one for each new list item we encounter.
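The same counting pattern works when only some elements should be counted: increment the accumulator conditionally. This is a sketch with hypothetical names (Counting, countWhere). The standard library already offers count for this, and the test compares against it.

```scala
object Counting {
  // Fundamental Five: ignore each element and add one to the running count
  def five[A](l: List[A]): Int = l.foldLeft(0)((a, _) => a + 1)

  // Same pattern, but only elements satisfying `p` increment the count
  // (hypothetical helper for illustration; List also provides `count`)
  def countWhere[A](l: List[A])(p: A => Boolean): Int =
    l.foldLeft(0)((n, a) => if (p(a)) n + 1 else n)
}
```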
Fundamental Six
Implement the contains function: return true if i is contained in l, and false otherwise.
def six[A](l: List[A], i: A): Boolean = {
  l.foldLeft(false)((a, c) => (c == i) || a)
}
At first, it seems unlikely that we could implement something like contains using a fold operation, but it is actually quite easy. We start off with false, since if we don't find any matches or the List is empty, we want to return false. From there, all we have to do is check whether the current item equals the item we are searching for and OR (||) that with the result for the items we have checked so far.
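One detail worth noticing is that the version above performs the comparison for every element, even after a match has been found. Below is a sketch of a variant, with the hypothetical name sixLazy, that puts the accumulator first so || skips the comparison once a match is found. foldLeft still walks the whole list either way.

```scala
object Contains {
  // Fundamental Six as written: c == i is evaluated for every element
  def six[A](l: List[A], i: A): Boolean =
    l.foldLeft(false)((a, c) => (c == i) || a)

  // Accumulator first: once `a` is true, || short-circuits and the
  // comparison is skipped, although foldLeft still visits every element
  def sixLazy[A](l: List[A], i: A): Boolean =
    l.foldLeft(false)((a, c) => a || c == i)
}
```

If stopping at the first match matters, the standard contains and exists methods on List return as soon as they find one.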
Fundamental Seven
Reverse the input list l.
def seven[A](l: List[A]): List[A] = {
  l.foldLeft(List.empty[A])((acc, i) => i :: acc)
}
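Here we rely on the fact that each new item is prepended to the front of the accumulated list (rather than appended to the end). Because foldLeft visits the elements from first to last, the first element ends up at the back and the last element ends up at the front.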
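The intermediate accumulators show how the list flips. This sketch uses scanLeft, which runs the same left fold but keeps each step. The names ReverseTrace and steps are hypothetical.

```scala
object ReverseTrace {
  // Fundamental Seven: prepend each element to the accumulator
  def seven[A](l: List[A]): List[A] =
    l.foldLeft(List.empty[A])((acc, i) => i :: acc)

  // scanLeft exposes each intermediate accumulator of the same foldLeft
  def steps[A](l: List[A]): List[List[A]] =
    l.scanLeft(List.empty[A])((acc, i) => i :: acc)
}

// ReverseTrace.steps(List(1, 2, 3)) returns
// List(List(), List(1), List(2, 1), List(3, 2, 1))
```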
Fundamental Eight
Transform the input list l into the equivalent MyList.
sealed abstract class MyList[+A] extends Product with Serializable
object MyList {
  case object Empty extends MyList[Nothing]
  final case class Cons[A](h: A, t: MyList[A]) extends MyList[A]
}
def eight[A](l: List[A]): MyList[A] = {
  l.foldRight(MyList.Empty: MyList[A])((a, b) => MyList.Cons(a, b))
}
A foldRight is preferable here since it allows us to copy the list over exactly without needing to reverse it in the process. This seems counter-intuitive at first, but if you look at the structure of a List (or MyList), it starts to make more sense. For example, the list List(1, 2, 3, 4) can also be written as a MyList as:
Cons(1, Cons(2, Cons(3, Cons(4, Empty))))
Notice that the last element in the MyList is nested the deepest inside this Cons structure. This tells us that we need to start with the last element in order to build this up, rather than starting with the first element. This is why we choose foldRight here instead of foldLeft. Using foldLeft would give us Cons(4, Cons(3, Cons(2, Cons(1, Empty)))), which is the reverse of what we wanted.
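Both directions can be checked side by side. In this sketch the names MyListConversions and eightLeft are hypothetical. eightLeft is the foldLeft version described above, which builds the reversed structure.

```scala
sealed abstract class MyList[+A] extends Product with Serializable
object MyList {
  case object Empty extends MyList[Nothing]
  final case class Cons[A](h: A, t: MyList[A]) extends MyList[A]
}

object MyListConversions {
  import MyList._

  // foldRight starts from the last element, which gets wrapped first (innermost)
  def eight[A](l: List[A]): MyList[A] =
    l.foldRight(Empty: MyList[A])((a, b) => Cons(a, b))

  // foldLeft starts from the first element, which ends up innermost,
  // so the resulting MyList is reversed
  def eightLeft[A](l: List[A]): MyList[A] =
    l.foldLeft(Empty: MyList[A])((b, a) => Cons(a, b))
}
```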
Fundamental Nine: foldLeft
Implement a tail-recursive foldLeft function for the List type.
import scala.annotation.tailrec
@tailrec
def foldLeft[A, B](l: List[A])(base: B)(f: (B, A) => B): B = l match {
  case Nil => base
  case head :: tail => foldLeft(tail)(f(base, head))(f)
}
Here we have an input List l, a base case base, and a function f that will be applied to each element of l to incrementally build up the B that is returned. If you aren't used to working with generic types, this function may look a little intimidating, so let's break it down.
As we explored previously, a List is composed of a series of items (built up using Cons, written :: for List), followed by Nil indicating the end of the list. For this reason, we begin our function by matching on the list and telling it what to do in each of these cases. In the case of Nil, we just return the base parameter. This is intuitive, since that parameter tells us what to return when we reach the end of the list (or when the list is empty to begin with). The other case matches on the :: part of a list and extracts the head and tail of the list. Here the head is the first element and the tail is the remainder of the list.
From there, we recursively call foldLeft again, this time passing the tail in for l because we want to move on to the next element in the list. The next argument, f(base, head), is probably the trickiest part. Here we call the function f with base as the input for B and the head of l as the input for A. The reason for this is that as we recursively call foldLeft, base will be equal to the result of the items that we have evaluated so far. Calling f(base, head) takes the result of evaluating all prior elements (base) and combines it with the latest element (head) using f. The last argument to foldLeft just passes f along so it can be used for the remaining elements. Because the recursive call is the very last thing the function does, the compiler can turn it into a loop, and @tailrec makes compilation fail if that is not possible.
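As a check that the hand-written foldLeft behaves like the standard one, here is a sketch that rebuilds two earlier fundamentals on top of it. The test also folds over a million-element list, which works because the recursion is compiled into a loop. The names MyFolds, sum, and reverse are hypothetical.

```scala
import scala.annotation.tailrec

object MyFolds {
  // Fundamental Nine: tail recursive, so the compiler turns it into a loop
  @tailrec
  def foldLeft[A, B](l: List[A])(base: B)(f: (B, A) => B): B = l match {
    case Nil          => base
    case head :: tail => foldLeft(tail)(f(base, head))(f)
  }

  // Earlier fundamentals rewritten on top of the hand-written foldLeft
  def sum(l: List[Int]): Int = foldLeft(l)(0)(_ + _)
  def reverse[A](l: List[A]): List[A] = foldLeft(l)(List.empty[A])((acc, a) => a :: acc)
}

// For List(1, 2, 3) the calls unfold as:
// foldLeft(List(2, 3))(0 + 1) -> foldLeft(List(3))(1 + 2) -> foldLeft(Nil)(3 + 3) -> 6
```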
Fundamental Ten: foldRight
Implement the foldRight function for the List type. Do NOT use the reverse operation on List as part of your implementation.
def foldRight[A, B](l: List[A])(base: B)(f: (A, B) => B): B = l match {
  case Nil => base
  case head :: tail => f(head, foldRight(tail)(base)(f))
}
Note that this implementation is not tail recursive, so a very long list can overflow the stack. It is certainly possible to make foldRight stack safe, but for the sake of this fundamental, we will not worry about this. You will notice that the implementation of foldRight is very similar to the implementation of foldLeft. The main difference is that the recursive call to foldRight and the call to the function f have switched places: f is now the outermost call. This essentially gives us the behavior of recursing all the way down to the base case and then building up our result by calling f on each head and the result of the recursive call as we come back up the stack.
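For comparison, this sketch shows one common way to get a stack-safe right fold: reverse the list once, then fold from the left with the arguments of f swapped. It deliberately breaks this fundamental's no-reverse rule, so treat it as an aside rather than a solution. The names RightFolds and foldRightSafe are hypothetical.

```scala
object RightFolds {
  // Fundamental Ten: not tail recursive; each element adds a stack frame
  def foldRight[A, B](l: List[A])(base: B)(f: (A, B) => B): B = l match {
    case Nil          => base
    case head :: tail => f(head, foldRight(tail)(base)(f))
  }

  // Stack-safe alternative (outside the fundamental's rules): reverse once,
  // then fold from the left, passing arguments to f in (element, accumulator) order
  def foldRightSafe[A, B](l: List[A])(base: B)(f: (A, B) => B): B =
    l.reverse.foldLeft(base)((b, a) => f(a, b))
}
```

Both versions produce the same result. The trade-off is an extra reversed copy of the list in exchange for constant stack depth.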
Challenge
With the fundamentals in mind, we are now ready to approach solving this month's challenge. There are multiple ways that we can solve this challenge, but here we are going to focus on solving it using a fold function.
Here is the code we are given at the start of the challenge:
import cats.data.NonEmptyList // assumed: NonEmptyList in the starter code comes from cats
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
final case class BranchName(value: String) extends AnyVal
final case class Branch(name: BranchName, manager: Employee.Manager)
def determineBranchToShutDown(branches: NonEmptyList[Branch]): BranchName = ???
Our task is to implement the determineBranchToShutDown function. Given a non-empty list of Branches, we need to return the name of the branch that has the worst profit-per-employee.
Implementing Employee Fold
To start off, we will implement a fold function for the Employee type that we can use in our algorithm. The first step in implementing this is to come up with the signature of the fold operation.
sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = ???
}
Here you can see that we are passing two parameters to our fold function. The first (ic) tells the fold what to do in the case of an individual contributor, and the second (m) tells it what to do when it encounters a manager. The parameters of each of these functions are derived from the constructors of IndividualContributor and Manager respectively. However, notice that in the function m, we have completely replaced the reference to Employee with a generic type A. This is because we want to be able to transform Employees into some type A and combine those to get our final A. This will be easier to see as we implement the body of the function.
sealed abstract class Employee extends Product with Serializable {
  // Companion members are not in scope inside the class body, so the cases are qualified
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info) => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
}
Here you will notice that the implementation of this method looks quite similar to foldRight on a List. Once again, this implementation is not stack-safe, but we will ignore that for now. An IndividualContributor is similar to Nil for a List because it does not contain any sub-objects that need to be folded over; it is effectively a base case. Manager, on the other hand, contains a List[Employee] that does need to be recursively folded over. We do this by mapping over the List[Employee] and calling fold on each element. Each call to fold returns an A, which is why m takes a List[A] instead of a List[Employee].
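To show that this fold is not tied to numbers, here is a sketch that uses it to collect every name in an employee tree, with each manager listed before their reports. The tree, the names, and the EmployeeFoldDemo object are all hypothetical illustration data. Here A is List[String], so m receives one List[String] per direct report.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
}

object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

object EmployeeFoldDemo {
  import Employee._

  // A = List[String]: each report's fold yields its own list of names,
  // and the manager prepends their name to the flattened result
  def names(e: Employee): List[String] =
    e.fold(info => List(info.name))((info, reports) => info.name :: reports.flatten)

  // Hypothetical tree; the numbers are made up for illustration
  val sample: Employee = Manager(
    EmployeeInfo("Michael", 0, 60),
    List(
      Manager(EmployeeInfo("Dwight", 120, 50), List(IndividualContributor(EmployeeInfo("Andy", 40, 45)))),
      IndividualContributor(EmployeeInfo("Jim", 100, 50))
    )
  )
}
```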
Counting Employees
Now that we have a fold function implemented, we can easily create methods that calculate profit-per-employee. The first function that we will implement gets the total number of employees (including managers) under an Employee. Calling it on a Branch's manager gives the headcount of the whole Branch.
sealed abstract class Employee extends Product with Serializable {
  ...
  def numOfEmployees: Int = this.fold(_ => 1)((_, emp) => 1 + emp.sum)
}
This numOfEmployees function counts 1 for every IndividualContributor. For each Manager, it counts 1 and adds that to the sum of the counts of their direct reports.
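Here is the count worked through on a small tree. The tree and the CountDemo object are hypothetical, and the fold is the qualified version of the one above.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }

  // 1 per individual contributor; 1 per manager plus the counts of their direct reports
  def numOfEmployees: Int = this.fold(_ => 1)((_, emp) => 1 + emp.sum)
}

object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

object CountDemo {
  import Employee._

  // Hypothetical tree: a manager with a sub-manager (who has one report)
  // and one individual contributor
  val sample: Employee = Manager(
    EmployeeInfo("Michael", 0, 60),
    List(
      Manager(EmployeeInfo("Dwight", 120, 50), List(IndividualContributor(EmployeeInfo("Andy", 40, 45)))),
      IndividualContributor(EmployeeInfo("Jim", 100, 50))
    )
  )
}

// Andy and Jim count 1 each, Dwight counts 1 + 1 = 2, Michael counts 1 + (2 + 1) = 4
```

A manager with no direct reports still counts as 1, because the sum of an empty List[Int] is 0.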
Calculating Profit
The next thing we will need for this algorithm is to calculate the total profit for a branch.
sealed abstract class Employee extends Product with Serializable {
  ...
  def totalProfit: Int =
    this.fold(i => i.salesRevenue - i.salary)((i, acc) => acc.sum + (i.salesRevenue - i.salary))
}
This function is almost identical to numOfEmployees, except that here we changed what we add for each Employee. Rather than adding 1 for each IndividualContributor and Manager, we now add the profit for each Employee, which is calculated with i.salesRevenue - i.salary.
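Here is the same kind of worked example for profit, using a hypothetical tree. The per-employee profits are Michael 0 - 60 = -60, Dwight 120 - 50 = 70, Andy 40 - 45 = -5, and Jim 100 - 50 = 50, for a total of 55. ProfitDemo is a hypothetical name.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }

  // Each employee contributes salesRevenue - salary; managers add their reports' totals
  def totalProfit: Int =
    this.fold(i => i.salesRevenue - i.salary)((i, acc) => acc.sum + (i.salesRevenue - i.salary))
}

object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

object ProfitDemo {
  import Employee._

  // Hypothetical tree; the numbers are made up for illustration
  val sample: Employee = Manager(
    EmployeeInfo("Michael", 0, 60),
    List(
      Manager(EmployeeInfo("Dwight", 120, 50), List(IndividualContributor(EmployeeInfo("Andy", 40, 45)))),
      IndividualContributor(EmployeeInfo("Jim", 100, 50))
    )
  )
}
```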
Single-Pass Approach
The only issue with doing these two calculations separately is that we have to traverse our entire tree of employees twice: once to get the number of employees and once to get the total profit. We can do better than that by combining these two calculations into a single fold as follows:
sealed abstract class Employee extends Product with Serializable {
  ...
  def profitPerEmployee: Double = {
    val result = this.fold(i => (i.salesRevenue - i.salary, 1))(
      (i, acc) => (acc.map(_._1).sum + (i.salesRevenue - i.salary), acc.map(_._2).sum + 1)
    )
    result._1.toDouble / result._2
  }
}
If you look at this implementation closely, you will see that we are just doing what we did in each of the separate folds before, but all in one fold now. We are adding 1 for each employee to keep a count and also calculating the profit for each employee. We are taking each of these and putting the result into a tuple that we use at the end to get the final answer by dividing the total profit by the total number of employees.
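As a sanity check, the single-pass result should match dividing the two separate folds. This sketch puts all three methods on one Employee and compares them on a hypothetical tree with a total profit of 55 across 4 employees. PerEmployeeDemo is a hypothetical name.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }

  // Two-pass building blocks: one traversal each
  def numOfEmployees: Int = fold(_ => 1)((_, emp) => 1 + emp.sum)
  def totalProfit: Int =
    fold(i => i.salesRevenue - i.salary)((i, acc) => acc.sum + (i.salesRevenue - i.salary))

  // Single pass: accumulate (total profit, employee count) together
  def profitPerEmployee: Double = {
    val result = fold(i => (i.salesRevenue - i.salary, 1))(
      (i, acc) => (acc.map(_._1).sum + (i.salesRevenue - i.salary), acc.map(_._2).sum + 1)
    )
    result._1.toDouble / result._2
  }
}

object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

object PerEmployeeDemo {
  import Employee._

  // Hypothetical tree: total profit 55 across 4 employees
  val sample: Employee = Manager(
    EmployeeInfo("Michael", 0, 60),
    List(
      Manager(EmployeeInfo("Dwight", 120, 50), List(IndividualContributor(EmployeeInfo("Andy", 40, 45)))),
      IndividualContributor(EmployeeInfo("Jim", 100, 50))
    )
  )
}

// PerEmployeeDemo.sample.profitPerEmployee returns 13.75 (55.0 / 4)
```

The acc.map(_._1).sum and acc.map(_._2).sum calls only walk each manager's list of direct results. The employee tree itself is still traversed once.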
Putting it Together
def determineBranchToShutDown(branches: NonEmptyList[Branch]): BranchName = {
  branches.toList.minBy(_.manager.profitPerEmployee).name
}
The implementation of this method is now very simple. We just have to take the minimum of all the branches (by profit-per-employee) and return its name.
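Here is an end-to-end sketch with two hypothetical branches. To stay free of external dependencies, it swaps cats' NonEmptyList for a plain List. That swap matters because minBy throws on an empty List, and an empty input is exactly what NonEmptyList rules out in the original signature. BranchName also drops extends AnyVal here, because value classes cannot be local classes and the sketch may be pasted into such a context. All branch and employee data is made up.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }

  def profitPerEmployee: Double = {
    val result = fold(i => (i.salesRevenue - i.salary, 1))(
      (i, acc) => (acc.map(_._1).sum + (i.salesRevenue - i.salary), acc.map(_._2).sum + 1)
    )
    result._1.toDouble / result._2
  }
}

object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

// The article's BranchName also extends AnyVal; dropped here so the sketch
// compiles even if it ends up nested in a local scope
final case class BranchName(value: String)
final case class Branch(name: BranchName, manager: Employee.Manager)

object ShutDown {
  import Employee._

  // Plain List instead of NonEmptyList (assumption): minBy throws on an empty list
  def determineBranchToShutDown(branches: List[Branch]): BranchName =
    branches.minBy(_.manager.profitPerEmployee).name

  // Hypothetical branches; the numbers are made up for illustration
  val scranton = Branch(BranchName("Scranton"), Manager(EmployeeInfo("Michael", 0, 60), List(
    Manager(EmployeeInfo("Dwight", 120, 50), List(IndividualContributor(EmployeeInfo("Andy", 40, 45)))),
    IndividualContributor(EmployeeInfo("Jim", 100, 50)))))
  val stamford = Branch(BranchName("Stamford"), Manager(EmployeeInfo("Josh", 50, 70), List(
    IndividualContributor(EmployeeInfo("Karen", 80, 40)))))
}

// Scranton: 55 / 4 = 13.75 per employee; Stamford: (-20 + 40) / 2 = 10.0,
// so Stamford is the branch to shut down
```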
Conclusion
You should now be able to run all of the tests in this project and see them passing! If you have encountered any issues along the way, feel free to reach out on Discord or Twitter.
The purpose of this challenge was to show the power of the fold operation and the variety of ways it can be used. fold and its siblings foldLeft and foldRight are great ways to write purely functional algorithms. Further, they can greatly simplify the implementations of many recursive algorithms.