February 2021 - Branch Wars

Created on February 4, 2021

Jim, Dwight, and Michael sitting in Karen's office.
"The Office" Episode "Branch Wars"

Introduction

Congratulations, Ryan Howard, you've just become the youngest executive in the history of Dunder Mifflin. No one knows how you pulled it off, but congrats regardless. As your first executive-level task, you've been assigned to figure out the profit per employee of each branch and decide which one should be shut down. Being the super-genius prodigy that you are, you decide to write a purely functional program that will compute the answer for you. No wonder you're an executive!

To complete this challenge, you will be given a list of branches. From there, you will determine the profit per employee of each branch and output the name of the branch with the lowest profit per employee so it can be shut down.

The model for a branch looks like this:

final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
final case class BranchName(value: String) extends AnyVal
final case class Branch(name: BranchName, manager: Employee.Manager)

Each branch has one manager, and a manager has zero or more direct reports. Each of those direct reports may be either a manager or an individual contributor, so the employees of a branch form a tree. For examples of what branches look like, see the ChallengeSpec in the source code.
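To make the tree shape concrete, here is a minimal sketch of one hypothetical branch built from the model above. The names and figures are made up for illustration and are not taken from the ChallengeSpec; BranchName is declared as a plain case class here so the snippet compiles in any script wrapper.

```scala
// The challenge model, repeated so this sketch compiles on its own.
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
// Plain case class in this sketch; the challenge model makes it a value class (extends AnyVal).
final case class BranchName(value: String)
final case class Branch(name: BranchName, manager: Employee.Manager)

import Employee._

// Hypothetical branch: Michael manages Dwight (who manages Jim) and Pam.
val scranton: Branch = Branch(
  BranchName("Scranton"),
  Manager(
    EmployeeInfo("Michael", salesRevenue = 0, salary = 60000),
    List(
      Manager(
        EmployeeInfo("Dwight", salesRevenue = 120000, salary = 50000),
        List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))
      ),
      IndividualContributor(EmployeeInfo("Pam", 0, 40000))
    )
  )
)
```

Dwight is both a direct report of Michael and a manager in his own right, which is exactly the nesting the fold later in this walk-through has to handle.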

Getting Started

  1. Clone the starter code here. The starter code contains acceptance tests. Get the ChallengeSpec to pass and you are done!
  2. Run the tests to ensure you are set up (using sbt test, for example). All tests will be failing for now.

Note that the starter code contains two main elements: the challenge and the fundamentals. The challenge is the problem described above, and the fundamentals are simpler, shorter problems you can solve to brush up on your knowledge before tackling the full challenge. Feel free to skip the fundamentals if you don't need them. Below you will find a walk-through of the solutions to all the fundamentals as well as the main challenge.

Tip: It may be helpful as you are going through to run only the tests for either the fundamentals or the challenge. For example, if you only want to run the fundamental tests, you can do so with sbt 'testOnly **.FundamentalsSpec'.

Fundamentals

Everything after this point is going through solutions to the fundamentals and the main challenge. Stop here if you wish to solve them on your own rather than following this walk-through!

If you want to follow along with the code as we go, you can find the solution repo here.

Fold All the Things

If you've been using Scala for a while, you may already be familiar with the fold, foldLeft, and foldRight operations. These operations are extremely powerful and more widely applicable than you might first think. When I first started using Scala, I didn't know you could use fold on structures other than Lists, but you can! The goal of these fundamentals is to introduce you to using fold operations to work with Lists as well as other structures.

Fundamental One

Sum the list of integers l.

def one(l: List[Int]): Int = l.fold(0)(_ + _)

This is probably the simplest and most canonical example of a fold operation. Here we are summing a list of numbers using fold. The 0 is the starting value of the fold: if the list l is empty, the fold returns 0. From there, the fold takes each element and adds it to the running total of all the elements that came before it (starting from 0). For example, if we input List(1, 2, 3), the following happens:

  1. 0 + 1 = 1
  2. 1 + 2 = 3
  3. 3 + 3 = 6

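Note: fold makes no guarantee about the order in which these operations are executed, so they could run in a different order than the one listed above (which is why the operation passed to fold should be associative). If we need them to run strictly in order, we should use foldLeft or foldRight.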
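The ordering only matters when the operation is not associative. A small sketch comparing the two directions with subtraction (non-associative) and addition (associative):

```scala
val nums = List(1, 2, 3)

// foldLeft groups from the left: ((0 - 1) - 2) - 3
val left = nums.foldLeft(0)(_ - _)

// foldRight groups from the right: 1 - (2 - (3 - 0))
val right = nums.foldRight(0)(_ - _)

// Addition is associative, so both directions give the same total.
val sumLeft = nums.foldLeft(0)(_ + _)
val sumRight = nums.foldRight(0)(_ + _)
```

Here left is -6 and right is 2, while both sums are 6. That is why summing with plain fold is safe but something like subtraction should pick a direction explicitly.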

Also, don't let the (_ + _) syntax scare you. This is just shorthand for:

{ (sumOfPreviousNumbers, latestNumber) =>
  sumOfPreviousNumbers + latestNumber
}

Fundamental Two

Concatenate the list of chars into a String.

def two(l: List[Char]): String = l.foldLeft("")(_ + _)

This is very similar to the last example, except this time we start with the empty String and use + to append each Char as it comes. We are also using foldLeft instead of fold. We need to do this because fold can only be used when the result type is the same as (or a supertype of) the element type of your List. For example, List("hello", "world").fold("")(_ + _) works because the List contains Strings and the return type is also a String. Here the elements are Chars but the result is a String, so fold does not fit. You can see this in concrete terms by looking at the signatures of fold and foldLeft:

def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1
def foldLeft[B](z: B)(op: (B, A) => B): B
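A short sketch of the difference in the signatures: fold ties the zero and the result to the element type, while foldLeft lets the accumulator type B differ from the element type A.

```scala
// fold: elements, zero, and result are all String.
val words: String = List("hello", "world").fold("")(_ + _)

// foldLeft: the accumulator is a String (B) while the elements are Chars (A).
val chars: String = List('h', 'i').foldLeft("")(_ + _)

// List('h', 'i').fold("")(_ + _) does not compile: the only A1 >: Char that
// also accepts "" is Any, and there is no + that combines two Any values.
```

words is "helloworld" and chars is "hi".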

Fundamental Three

Stringify the optional input.

  • If the option is None, return "None" as a string.
  • If the option is defined, return the value it contains as a string wrapped in Some(...).

def three(l: Option[String]): String = l.fold("None")(s => s"Some($s)")

Of course, you could easily implement this function using a match statement, but this gives a good intuition for what a fold is doing. You provide what should be done in the None (empty) case and the Some (non-empty) case and the fold function takes care of the rest.
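For comparison, here is the same function written both ways. The threeMatch name is just for this side-by-side sketch; each argument to fold corresponds to one case of the match.

```scala
// The fold-based version from Fundamental Three.
def three(l: Option[String]): String = l.fold("None")(s => s"Some($s)")

// The equivalent pattern match: the None branch is fold's first argument,
// and the Some branch is fold's second argument.
def threeMatch(l: Option[String]): String = l match {
  case None    => "None"
  case Some(s) => s"Some($s)"
}
```

Both return "None" for None and "Some(x)" for Some("x").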

Fundamental Four

Implement the fold operation for JobStatus.

import java.time.Instant

sealed abstract class JobStatus extends Product with Serializable {
  // One parameter per case of the ADT; `stopped` is by-name so it is only evaluated when needed.
  def fold[A](stopped: => A)(running: JobStatus.Running => A): A = this match {
    case r: JobStatus.Running => running(r)
    case JobStatus.Stopped    => stopped
  }
}
object JobStatus {
  final case class Running(startedAt: Instant) extends JobStatus
  case object Stopped extends JobStatus
}

When implementing a fold operation for an algebraic data type (ADT) such as JobStatus, it is simplest to have one parameter in the fold for each possibility in the ADT. In this case, JobStatus can be either Stopped or Running. For this reason, we add (stopped: => A) for stopped and (running: JobStatus.Running => A) for running. Notice that the stopped parameter is passed by name so that we don't have to execute this code unless the JobStatus is in fact Stopped.

Now let's use the fold operation we implemented to return "Stopped" for a JobStatus.Stopped job and "Started at $startedAt" for a JobStatus.Running job.

def four(l: JobStatus): String = l.fold("Stopped")(r => s"Started at ${r.startedAt}")

This implementation is very similar to the one we did for Option above. We provide what to do when a job is Stopped and what to do when a job is Running and the fold takes care of the rest.
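A minimal sketch exercising four on both cases, plus a check that the by-name stopped argument really is skipped for a Running job. Instant.EPOCH is used only as a convenient fixed timestamp.

```scala
import java.time.Instant

sealed abstract class JobStatus extends Product with Serializable {
  def fold[A](stopped: => A)(running: JobStatus.Running => A): A = this match {
    case r: JobStatus.Running => running(r)
    case JobStatus.Stopped    => stopped
  }
}
object JobStatus {
  final case class Running(startedAt: Instant) extends JobStatus
  case object Stopped extends JobStatus
}

def four(l: JobStatus): String = l.fold("Stopped")(r => s"Started at ${r.startedAt}")

val stoppedMsg = four(JobStatus.Stopped)
val runningMsg = four(JobStatus.Running(Instant.EPOCH))

// `stopped` is by-name, so sys.error is never called for a Running job.
val lazyCheck: String =
  JobStatus.Running(Instant.EPOCH).fold[String](sys.error("not evaluated"))(_ => "ok")
```

runningMsg is "Started at 1970-01-01T00:00:00Z", and lazyCheck is "ok" because the error branch is never evaluated.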

Fundamental Five

Return the length of the input list l.

def five[A](l: List[A]): Int = l.foldLeft(0)((a, _) => a + 1)

Here we are able to compute the length of the input list by starting with 0 and then just adding one for each new list item we encounter.
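A quick check of five: the step function ignores each element's value and only increments the count, so the accumulator goes 0, 1, 2, 3 for a three-element list.

```scala
def five[A](l: List[A]): Int = l.foldLeft(0)((a, _) => a + 1)

// Works for any element type because the elements themselves are never inspected.
val lengthOfThree = five(List("a", "b", "c"))
val lengthOfEmpty = five(List.empty[Int])
```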

Fundamental Six

Implement the contains function where true is returned if i is contained inside of l and otherwise false is returned.

def six[A](l: List[A], i: A): Boolean = {
  l.foldLeft(false)((a, c) => (c == i) || a)
}

At first, it feels unlikely that we would be able to implement something like a contains function using a fold operation, but it is actually quite easy. We start off with false since if we don't find any matches or the List is empty then we want to return false. From there all we have to do is check whether or not the current item is equal to the item we are searching for and OR (||) that with the result of the items we have checked so far.
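One trade-off worth knowing: a fold always visits every element, even after a match has been found. This sketch counts the visits with a local var (used here only to observe the traversal):

```scala
def six[A](l: List[A], i: A): Boolean =
  l.foldLeft(false)((a, c) => (c == i) || a)

// Count how many elements the fold actually visits while searching for 2.
var visited = 0
val found = List(1, 2, 3, 4).foldLeft(false) { (a, c) =>
  visited += 1
  (c == 2) || a
}
```

found is true but visited is 4: the fold keeps walking after the match, whereas the standard library's List#contains and List#exists stop at the first match.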

Fundamental Seven

Reverse the input list l.

def seven[A](l: List[A]): List[A] = {
  l.foldLeft(List.empty[A])((acc, i) => i :: acc)
}

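Here we rely on the fact that each new item is prepended to the front of the accumulator (with ::) rather than appended to the end. Because foldLeft visits the elements from first to last, the first element is prepended first and therefore ends up last, which reverses the list.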
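scanLeft applies the same step function as foldLeft but keeps every intermediate accumulator, which makes the prepending visible step by step:

```scala
// Same step as in `seven`, but every intermediate accumulator is kept.
val steps = List(1, 2, 3).scanLeft(List.empty[Int])((acc, i) => i :: acc)

def seven[A](l: List[A]): List[A] = l.foldLeft(List.empty[A])((acc, i) => i :: acc)
```

steps is List(List(), List(1), List(2, 1), List(3, 2, 1)); the last entry is exactly what seven(List(1, 2, 3)) returns.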

Fundamental Eight

Transform the input list l into the equivalent MyList.

sealed abstract class MyList[+A] extends Product with Serializable
object MyList {
  case object Empty extends MyList[Nothing]
  final case class Cons[A](h: A, t: MyList[A]) extends MyList[A]
}

def eight[A](l: List[A]): MyList[A] = {
  l.foldRight(MyList.Empty: MyList[A])((a, b) => MyList.Cons(a, b))
}

A foldRight is preferable here since it allows us to copy the list over exactly without needing to reverse it in the process. This seems counter-intuitive at first, but if you look at the structure of a List (or MyList), then it starts to make more sense. For example, the list List(1, 2, 3, 4) can also be written as a MyList as:

Cons(1, Cons(2, Cons(3, Cons(4, Empty))))

Notice that the last element in the MyList is nested into this Cons structure the furthest. This tells us that we need to start with the last element in order to build this up, rather than starting with the first element. This is why we choose foldRight here instead of foldLeft. Using foldLeft would give us Cons(4, Cons(3, Cons(2, Cons(1, Empty)))) which is the reverse of what we wanted.
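To see the difference side by side, this sketch builds the same input with foldRight (eight) and with foldLeft (a hypothetical eightLeft, which takes its parameters in the (acc, a) order foldLeft expects):

```scala
sealed abstract class MyList[+A] extends Product with Serializable
object MyList {
  case object Empty extends MyList[Nothing]
  final case class Cons[A](h: A, t: MyList[A]) extends MyList[A]
}

def eight[A](l: List[A]): MyList[A] =
  l.foldRight(MyList.Empty: MyList[A])((a, b) => MyList.Cons(a, b))

// Hypothetical foldLeft variant: note the swapped parameter order (acc, a).
def eightLeft[A](l: List[A]): MyList[A] =
  l.foldLeft(MyList.Empty: MyList[A])((acc, a) => MyList.Cons(a, acc))

import MyList._

val viaRight = eight(List(1, 2, 3, 4))
val viaLeft = eightLeft(List(1, 2, 3, 4))
```

viaRight preserves the order, while viaLeft comes out reversed.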

Fundamental Nine: foldLeft

Implement a tail-recursive foldLeft function for the List type.

import scala.annotation.tailrec

@tailrec
def foldLeft[A, B](l: List[A])(base: B)(f: (B, A) => B): B = l match {
  case Nil          => base
  case head :: tail => foldLeft(tail)(f(base, head))(f)
}

Here we have an input List l, a base case base, and a function f that will be applied to each element of l to incrementally build up the B that is returned. If you aren't used to working with generic types, this function may look a little intimidating. So let's break it down.

As we explored previously, a List is composed of a series of items (built up using Cons or ::), followed by Nil indicating the end of the list. For this reason, we begin our function by matching on the list and telling it what to do in each of these cases. In the Nil case, we just return the base parameter. This is intuitive, since that parameter tells us what to produce for an empty list. The other case matches on the :: part of a list and extracts its head and tail: the head is the first element and the tail is the remainder of the list. From there, we recursively call foldLeft again, this time passing the tail in for l because we want to move on to the next element.

The next argument, f(base, head), is probably the trickiest part. Here we call f with base as its B input and the head of l as its A input. The reason is that, as we recursively call foldLeft, base holds the result of all the items evaluated so far. Calling f(base, head) combines that accumulated result (base) with the latest element (head), and the value it returns becomes the new base for the rest of the list. The last argument to foldLeft just passes f along so it can be used for the remaining elements. Because the recursive call is the very last thing the function does, the compiler can turn it into a loop, and @tailrec makes compilation fail if that is not possible.
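A small usage sketch with the evaluation steps written out. The function is placed in an object (the name ListFolds is just for this example) so @tailrec sees a method that cannot be overridden.

```scala
import scala.annotation.tailrec

object ListFolds {
  @tailrec
  def foldLeft[A, B](l: List[A])(base: B)(f: (B, A) => B): B = l match {
    case Nil          => base
    case head :: tail => foldLeft(tail)(f(base, head))(f)
  }
}

// Evaluation of foldLeft(List(1, 2, 3))(0)(_ + _):
//   foldLeft(List(1, 2, 3))(0)   base = 0
//   foldLeft(List(2, 3))(1)      base = f(0, 1) = 1
//   foldLeft(List(3))(3)         base = f(1, 2) = 3
//   foldLeft(Nil)(6)             base = f(3, 3) = 6, returned as-is
val total = ListFolds.foldLeft(List(1, 2, 3))(0)(_ + _)

// The accumulator type (String) can differ from the element type (Char).
val joined = ListFolds.foldLeft(List('a', 'b', 'c'))("")(_ + _)
```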

Fundamental Ten: foldRight

Implement the foldRight function for the List type. Do NOT use the reverse operation on List as part of your implementation.

def foldRight[A, B](l: List[A])(base: B)(f: (A, B) => B): B = l match {
  case Nil          => base
  case head :: tail => f(head, foldRight(tail)(base)(f))
}

Note that this implementation is not tail-recursive. It is certainly possible to make foldRight stack-safe, but for the sake of this fundamental we will not worry about that. You will notice that the implementation of foldRight is very similar to that of foldLeft. The main difference is that the recursive call and the call to f have switched places: foldLeft applies f first and passes its result into the recursive call, whereas foldRight makes the recursive call first and passes its result to f. This gives us the behavior of recursing all the way down to the base case and then building up the result by calling f on each head and the result of the recursive call as we come back up the stack.
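A usage sketch showing the right-nested evaluation order and the list-copying behavior (ListFolds is just a container name for this example):

```scala
object ListFolds {
  // Not tail-recursive: f runs only after the recursive call returns.
  def foldRight[A, B](l: List[A])(base: B)(f: (A, B) => B): B = l match {
    case Nil          => base
    case head :: tail => f(head, foldRight(tail)(base)(f))
  }
}

// foldRight(List(1, 2, 3))(0)(_ - _) expands to
//   f(1, f(2, f(3, 0))) = 1 - (2 - (3 - 0)) = 2
val diff = ListFolds.foldRight(List(1, 2, 3))(0)(_ - _)

// Prepending with :: on the way back up rebuilds the list in its original order.
val copy = ListFolds.foldRight(List(1, 2, 3))(List.empty[Int])(_ :: _)
```

Because each element waits on the stack for the rest of the list to be folded, very long lists can overflow the stack with this version, which is the stack-safety caveat mentioned above.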

Challenge

With the fundamentals in mind, we are now ready to approach solving this month's challenge. There are multiple ways that we can solve this challenge, but here we are going to focus on solving it using a fold function.

To start out, here is the model and the function signature we are given:

final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
final case class BranchName(value: String) extends AnyVal
final case class Branch(name: BranchName, manager: Employee.Manager)

def determineBranchToShutDown(branches: NonEmptyList[Branch]): BranchName = ???

Our task is to implement the determineBranchToShutDown function. Given a non-empty list of Branches, we need to return the name of the branch with the worst (lowest) profit per employee.

Implementing Employee Fold

To start off, we will implement a fold function for the Employee type that we can use in our algorithm. The first step in implementing this is to come up with the signature of the fold operation.

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = ???
}

Here you can see that we are passing two parameters to our fold function. The first (ic) tells the fold what to do in the case of an individual contributor, and the second (m) tells it what to do when it encounters a manager. The parameters of each of these functions are derived from the constructors of IndividualContributor and Manager, respectively. However, notice that in the function m, we have replaced the reference to Employee with the generic type A. This is because we want to transform each Employee into some type A and combine those to get our final A. This will be easier to see as we implement the body of the function.

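sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    // Base case: an individual contributor has no reports to recurse into.
    case Employee.IndividualContributor(info) => ic(info)
    // Recursive case: fold every report into an A, then combine them with the manager's info.
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
}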

Here you will notice that the implementation of this method looks quite similar to foldRight on a List. Once again, this implementation is not stack-safe, but we will ignore that for now. An IndividualContributor is similar to Nil for a List because it does not contain any sub-objects that need to be folded over; it is effectively a base case. A Manager, on the other hand, contains a List[Employee] that does need to be recursively folded over. We do this by mapping over the List[Employee] and calling fold on each element. Each of those recursive calls returns an A, which is why m receives a List[A] instead of a List[Employee].
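Before using the fold for the challenge itself, here is a small sketch of two other things it can compute on a hypothetical team (the names and numbers are made up): every employee's name, and the depth of the org chart.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
}
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

import Employee._

// Hypothetical team: Michael manages Dwight (who manages Jim) and Pam.
val team: Employee = Manager(
  EmployeeInfo("Michael", 0, 60000),
  List(
    Manager(EmployeeInfo("Dwight", 120000, 50000), List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))),
    IndividualContributor(EmployeeInfo("Pam", 0, 40000))
  )
)

// A = List[String]: each manager's name followed by the names from each report's subtree.
val names: List[String] = team.fold(i => List(i.name))((i, reports) => i.name :: reports.flatten)

// A = Int: an IC has depth 1; a manager adds 1 to its deepest report (0 if it has none).
val depth: Int = team.fold(_ => 1)((_, reports) => 1 + (0 :: reports).max)
```

names is List("Michael", "Dwight", "Jim", "Pam") and depth is 3. Only the two functions passed to fold change; the traversal stays the same.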

Counting Employees

Now that we have a fold function implemented, we can easily create methods that will calculate profit-per-employee. The first function that we will implement here is one to get the total number of employees (counting managers) in a Branch.

sealed abstract class Employee extends Product with Serializable {
  ...
  def numOfEmployees: Int = this.fold(_ => 1)((_, emp) => 1 + emp.sum)
}

This numOfEmployees function counts 1 for every IndividualContributor, and for each Manager it counts 1 plus the sum of the employee counts of their direct reports.
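A minimal sketch running numOfEmployees on a hypothetical team, with the per-node counts written out in a comment:

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
  // 1 per individual contributor; 1 per manager plus the counts of their reports.
  def numOfEmployees: Int = this.fold(_ => 1)((_, emp) => 1 + emp.sum)
}
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

import Employee._

// Hypothetical team: Michael manages Dwight (who manages Jim) and Pam.
val team: Employee = Manager(
  EmployeeInfo("Michael", 0, 60000),
  List(
    Manager(EmployeeInfo("Dwight", 120000, 50000), List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))),
    IndividualContributor(EmployeeInfo("Pam", 0, 40000))
  )
)

// Jim = 1, Pam = 1, Dwight = 1 + 1 = 2, Michael = 1 + (2 + 1) = 4
val headcount = team.numOfEmployees

// A manager with no direct reports still counts as one employee (Nil.sum is 0).
val soloCount = Manager(EmployeeInfo("Karen", 0, 0), Nil).numOfEmployees
```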

Calculating Profit

The next thing we will need for this algorithm is to calculate the total profit for a branch.

sealed abstract class Employee extends Product with Serializable {
  ...
  def totalProfit: Int =
    this.fold(i => i.salesRevenue - i.salary)((i, acc) => acc.sum + (i.salesRevenue - i.salary))
}

This function is almost identical to numOfEmployees except that here we changed what we are doing for each Employee. Rather than adding 1 for each IndividualContributor and Manager, we now add the profit for each Employee which is calculated with i.salesRevenue - i.salary.
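The same hypothetical team, this time with totalProfit and the arithmetic spelled out per employee:

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
  // Each employee contributes salesRevenue - salary; managers add their reports' totals.
  def totalProfit: Int =
    this.fold(i => i.salesRevenue - i.salary)((i, acc) => acc.sum + (i.salesRevenue - i.salary))
}
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

import Employee._

val team: Employee = Manager(
  EmployeeInfo("Michael", 0, 60000),
  List(
    Manager(EmployeeInfo("Dwight", 120000, 50000), List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))),
    IndividualContributor(EmployeeInfo("Pam", 0, 40000))
  )
)

// Michael: 0 - 60000      = -60000
// Dwight:  120000 - 50000 =  70000
// Jim:     110000 - 50000 =  60000
// Pam:     0 - 40000      = -40000
// Total:   -60000 + 70000 + 60000 - 40000 = 30000
val profit = team.totalProfit
```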

Simpler Approach

The only issue with doing these two calculations separately is that we have to traverse our entire tree of employees twice: once to get the number of employees and once to get the total profit. We can do better by combining the two calculations into a single fold as follows:

sealed abstract class Employee extends Product with Serializable {
  ...
  def profitPerEmployee: Double = {
    // Accumulate (total profit, employee count) in a single pass over the tree.
    val result = this.fold(i => (i.salesRevenue - i.salary, 1))(
      (i, acc) => (acc.map(_._1).sum + (i.salesRevenue - i.salary), acc.map(_._2).sum + 1)
    )
    result._1.toDouble / result._2
  }
}

If you look at this implementation closely, you will see that we are just doing what we did in each of the separate folds before, but all in one fold now. We are adding 1 for each employee to keep a count and also calculating the profit for each employee. We are taking each of these and putting the result into a tuple that we use at the end to get the final answer by dividing the total profit by the total number of employees.
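As a variation, the tuple can be replaced with a small named type. In this sketch Stats is a hypothetical helper, not part of the challenge model; it makes the two fields self-describing and lets each manager combine its reports with a single foldLeft instead of mapping over them twice.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)

// Hypothetical helper: a named (profit, headcount) pair that knows how to add itself.
final case class Stats(profit: Int, headcount: Int) {
  def +(that: Stats): Stats = Stats(profit + that.profit, headcount + that.headcount)
}

sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
  def profitPerEmployee: Double = {
    val total = this.fold(i => Stats(i.salesRevenue - i.salary, 1)) { (i, reports) =>
      // Start from the manager's own stats, then add each report's subtree stats.
      reports.foldLeft(Stats(i.salesRevenue - i.salary, 1))(_ + _)
    }
    total.profit.toDouble / total.headcount
  }
}
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}

import Employee._

val team: Employee = Manager(
  EmployeeInfo("Michael", 0, 60000),
  List(
    Manager(EmployeeInfo("Dwight", 120000, 50000), List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))),
    IndividualContributor(EmployeeInfo("Pam", 0, 40000))
  )
)

// 30000 total profit across 4 employees
val ppe = team.profitPerEmployee
```

ppe is 7500.0, the same result the tuple-based version produces; the difference is purely in readability.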

Putting it Together

def determineBranchToShutDown(branches: NonEmptyList[Branch]): BranchName = {
  branches.toList.minBy(_.manager.profitPerEmployee).name
}

The implementation of this method is now very simple. We just have to take the minimum of all the branches (by profit-per-employee) and return its name.
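An end-to-end sketch with two hypothetical branches. To keep it dependency-free, it assumes a plain List in place of NonEmptyList; note that minBy throws on an empty List, which is exactly the case the NonEmptyList parameter rules out.

```scala
final case class EmployeeInfo(name: String, salesRevenue: Int, salary: Int)
sealed abstract class Employee extends Product with Serializable {
  def fold[A](ic: EmployeeInfo => A)(m: (EmployeeInfo, List[A]) => A): A = this match {
    case Employee.IndividualContributor(info)  => ic(info)
    case Employee.Manager(info, directReports) => m(info, directReports.map(_.fold(ic)(m)))
  }
  def profitPerEmployee: Double = {
    val result = this.fold(i => (i.salesRevenue - i.salary, 1))(
      (i, acc) => (acc.map(_._1).sum + (i.salesRevenue - i.salary), acc.map(_._2).sum + 1)
    )
    result._1.toDouble / result._2
  }
}
object Employee {
  final case class Manager(info: EmployeeInfo, directReports: List[Employee]) extends Employee
  final case class IndividualContributor(info: EmployeeInfo) extends Employee
}
final case class BranchName(value: String)
final case class Branch(name: BranchName, manager: Employee.Manager)

// Assumption: List stands in for NonEmptyList; callers must not pass an empty list.
def determineBranchToShutDown(branches: List[Branch]): BranchName =
  branches.minBy(_.manager.profitPerEmployee).name

import Employee._

// (-60000 + 60000) / 2 = 0.0
val scranton = Branch(BranchName("Scranton"),
  Manager(EmployeeInfo("Michael", 0, 60000), List(IndividualContributor(EmployeeInfo("Jim", 110000, 50000)))))

// (-70000 + -5000) / 2 = -37500.0
val stamford = Branch(BranchName("Stamford"),
  Manager(EmployeeInfo("Josh", 0, 70000), List(IndividualContributor(EmployeeInfo("Andy", 40000, 45000)))))

val toClose = determineBranchToShutDown(List(scranton, stamford))
```

Stamford has the lower profit per employee, so toClose is BranchName("Stamford").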

Conclusion

You should now be able to run all of the tests in this project and see them passing! If you have encountered any issues along the way, feel free to reach out on Discord or Twitter.

The purpose of this challenge was to show the power of the fold operation and the variety of ways it can be used. fold and its siblings foldLeft and foldRight are great ways to write purely functional algorithms. Further, they can greatly simplify the implementations of many recursive algorithms.
