2022.01.12

Engineering

Exhaustive Union Matching in Python

Tobias Pfeiffer

Engineer

Pattern matching on algebraic data types is a powerful technique to process a given input and many programming languages have adopted it in one way or another. A check on whether a given match is “exhaustive”, i.e., covers all possible inputs, is helpful to avoid bugs when the set of possible inputs is extended; for example, when new enumeration values are added.

In this blog post I will first briefly explain the concept of sum types, a kind of algebraic data type, and give examples of pattern matching on these types in Rust and Scala. Then I will show how to define sum types in recent Python versions, and I will explain how the mypy type checker can (to a limited degree) be used to add exhaustiveness checks to Python code working with these types.

tl;dr

Whenever you want to deal with all cases of an enum.Enum, a typing.Literal or a typing.Union in an if/elif construct, add a final else block where you pass the value that “can never exist” to a special function:

if x is Color.RED:
  print("red")
elif x is Color.GREEN:
  print("green")
else:
  assert_exhaustiveness(x)

Define this special function called assert_exhaustiveness() as follows:

from typing import NoReturn

def assert_exhaustiveness(x: NoReturn) -> NoReturn:
  """Provide an assertion at type-check time that this function is never called."""
  raise AssertionError(f"Invalid value: {x!r}")

If you do so, an unhandled variant will cause a type-check error:

error: Argument 1 to "assert_exhaustiveness" has incompatible type "Literal[Color.BLUE]"; expected "NoReturn"  [arg-type]
assert_exhaustiveness(x)
                    ^

Credit for this goes to the participants of the discussion in https://github.com/python/mypy/issues/5818

Sum Types and Pattern Matching

Algebraic data types are an important concept in programming, especially in functional programming. In particular, sum types are helpful to model concepts where a value can be exactly one of a limited number of variants. In the simplest case, bool could be considered a sum type with the two variants true and false. An enumeration such as Weekday is a sum type with constant variants Monday, Tuesday, etc.

Sum types become very powerful when the variants are not only constant but can hold some values as well. For example, a type UserInput could have variants KeyPress(key) holding the code of the pressed key, or MouseClick(x, y) holding the coordinates of the click.

Many programming languages allow you to define algebraic data types and provide more or less powerful tools to process them. For example, the Rust programming language lets you define sum types using the enum keyword:

enum UserInput {
    Quit,
    KeyPress(char),
    MouseClick(i32, i32),
}

A variable of type UserInput can then be processed using pattern matching:

match evt {
    UserInput::Quit => println!("done"),
    UserInput::KeyPress(c) => println!("user pressed key '{}'", c),
    UserInput::MouseClick(x, y) => println!("user clicked at {},{}", x, y),
};

In the Scala programming language, sum types can be defined as case classes inheriting from a common trait:

sealed trait UserInput

case class Quit() extends UserInput
case class KeyPress(c: Char) extends UserInput
case class MouseClick(x: Int, y: Int) extends UserInput

Again, pattern matching is used to process objects of type UserInput, with a syntax very similar to Rust:

evt match {
    case Quit() => println("done")
    case KeyPress(c) => println(s"user pressed key $c")
    case MouseClick(x,y) => println(s"user clicked at $x,$y")
}

This blog post is supposed to be mostly about Python, not about Scala or Rust, but let me make a couple of remarks before moving on.

Sum types are so powerful because they are limited to a small number of exactly known variants. A value of type UserInput as defined above can be exactly one of the three defined variants; it cannot be empty and it cannot have any other unexpected type (cf. “Parse, don’t validate”).

Pattern matching as shown above is a natural way to process sum types. On the one hand, you cannot check for things that obviously cannot happen: adding a match arm 1..3 => … to the Rust code above yields a compile-time error:

11 |     match evt {
   |           --- this match expression has type `UserInput`
...
15 |         1..3 => print!("abc"),
   |         ^^^^ expected enum `UserInput`, found integer

On the other hand, if the language supports it, pattern matches can be checked by the compiler to be exhaustive, i.e., they cover all the possible variants. Removing one of the match lines from the Rust code above yields a compile-time error:

2  | / enum UserInput {
3  | |     Quit,
4  | |     KeyPress(char),
5  | |     MouseClick(i32, i32),
   | |     ---------- not covered
6  | | }
   | |_- `UserInput` defined here
...
11 |       match evt {
   |             ^^^ pattern `MouseClick(_, _)` not covered

In the case of Scala, the compilation succeeds but a warning is printed by the compiler:

warning: match may not be exhaustive.
It would fail on the following input: MouseClick(_, _)
evt match {
^

Note that for Scala this only works if we declare the UserInput trait as sealed, which means that all subclasses must be defined in the same file as the trait itself. Otherwise there could be other subclasses of UserInput and there would be no way of ensuring that we have covered all of the possible subclasses.

Sum Types in Python

As someone who works mostly in Python but also enjoys working with static type systems, I generally use Python with type hints and have my code checked by mypy, and tend to try and reproduce, say, idiomatic Scala in Python. Maybe occasionally a bit too hard, some might say.

Perhaps unsurprisingly given its dynamic typing, historically there has been no support for defining sum types in Python; even the enum module was only added in Python 3.4. Python 3.5 introduced type hints, and the corresponding PEP 484 defines the Union type. Python 3.8 brought Literal types, defined in PEP 586. These are the three ways to define sum types in Python, and the Union type is the only one that lets you define variants holding values:

from dataclasses import dataclass
from typing import Union

@dataclass
class Quit:
    pass

@dataclass
class KeyPress:
    c: str

@dataclass
class MouseClick:
    x: int
    y: int

UserInput = Union[Quit, KeyPress, MouseClick]

Note that we could make UserInput an empty abstract class and then define the three variants as subclasses, but the absence of a sealed annotation (which was actually proposed in the now-superseded PEP 622) means that there could always be yet another subclass of UserInput floating around that we don’t know about.
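
To illustrate why an open class hierarchy does not give us a closed set of variants, here is a minimal sketch. The names UserInputBase and describe() are hypothetical and only serve this example; ScreenTouch stands in for a subclass defined in some other module that the author of describe() does not know about.

```python
from abc import ABC
from dataclasses import dataclass


class UserInputBase(ABC):
    """Hypothetical open base class; any module may add a subclass."""


@dataclass
class Quit(UserInputBase):
    pass


@dataclass
class KeyPress(UserInputBase):
    c: str


# Imagine this class lives in a different module.
@dataclass
class ScreenTouch(UserInputBase):
    x: int
    y: int


def describe(evt: UserInputBase) -> str:
    if isinstance(evt, Quit):
        return "done"
    if isinstance(evt, KeyPress):
        return f"user pressed key {evt.c}"
    # Without a sealed hierarchy, a type checker cannot rule out
    # further subclasses, so a fallback like this stays reachable.
    return f"unknown input: {evt!r}"


print(describe(KeyPress("a")))
print(describe(ScreenTouch(1, 2)))
```

With a Union, by contrast, the list of variants is written down in one place, which is what makes the narrowing described below possible.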

Before Python 3.10, there was no native pattern matching in Python. A match syntax exists by now, but at the time of writing this post mypy does not yet support it. So let us assume Python <= 3.9: if a variable is annotated as UserInput and mypy passes, we can be sure that it is exactly one of the three defined variants, and the way to match the actual type is the if/elif construct:

if isinstance(evt, Quit):
    print("done")
elif isinstance(evt, KeyPress):
    print(f"user pressed key {evt.c}")
elif isinstance(evt, MouseClick):
    print(f"user clicked at {evt.x},{evt.y}")

mypy checks that in each branch we use the detected variant correctly. For example, if we wrote print(f"user pressed key {evt.x}") in the elif isinstance(evt, KeyPress) block, it would complain:

error: "KeyPress" has no attribute "x"

Also, when we try to add a branch that cannot possibly match, such as elif isinstance(evt, int), then mypy fails as well, although it needs the --warn-unreachable flag to do so:

error: Statement is unreachable

However, as the code above stands, there is no way to find out whether we have forgotten to cover a variant. In other words, the match is not checked for exhaustiveness.

Exhaustive Union Matching

The Python code above does not have an else clause. However, it works just as well if we replace the last isinstance() check with an else:

if isinstance(evt, Quit):
    print("done")
elif isinstance(evt, KeyPress):
    print(f"user pressed key {evt.c}")
else:
    print(f"user clicked at {evt.x},{evt.y}")

mypy understands the Union type well enough to know that after ruling out Quit and KeyPress, the only thing that is possibly left is an instance of MouseClick. In fact, while it is easy to overlook, the section called “Support for singleton types in unions” of PEP 484 already mentions that a type checker should infer the only remaining type of Union or Enum values in an else block. The PEP 586 section called “Interactions with enums and exhaustiveness checks” is more explicit and states: “Type checkers should be capable of performing exhaustiveness checks when working [with] Literal types that have a closed number of variants, such as enums.”

Unfortunately, using else to branch into the block to process the only remaining Union variant is also not without its problems. Assume we add another variant

@dataclass
class ScreenTouch:
    x: int
    y: int
    pressure: float

to the UserInput Union. This, too, would make mypy pass (all elements of the narrowed-down Union[MouseClick, ScreenTouch] have x and y members), but it would probably not be what you want. While the case of adding another variant that accidentally looks just like the variant you are treating in the else block may not be too common, think of the case of adding another variant to an existing Enum as a less contrived example.
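
The following sketch shows the resulting runtime behavior. It wraps the if/elif/else construct from above in a hypothetical function describe() that returns the message instead of printing it, so that the result is easy to inspect.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Quit:
    pass


@dataclass
class KeyPress:
    c: str


@dataclass
class MouseClick:
    x: int
    y: int


@dataclass
class ScreenTouch:
    x: int
    y: int
    pressure: float


UserInput = Union[Quit, KeyPress, MouseClick, ScreenTouch]


def describe(evt: UserInput) -> str:
    if isinstance(evt, Quit):
        return "done"
    elif isinstance(evt, KeyPress):
        return f"user pressed key {evt.c}"
    else:
        # Intended for MouseClick only, but ScreenTouch ends up here too.
        return f"user clicked at {evt.x},{evt.y}"


# The touch event is silently reported as a mouse click.
print(describe(ScreenTouch(3, 4, 0.5)))
```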

One way to resolve this problem was brought up by GitHub user @bluetech in https://github.com/python/mypy/issues/5818, inspired by the TypeScript approach to exhaustiveness checking. It has since been described in the mypy Literal documentation as well, and is based on the following approach:

  1. For a value x, handle each possible variant of the corresponding sum type in its own if/elif block.
  2. Add an else block. If indeed all variants were covered before, the only remaining possible type for x is the empty Union.
  3. In the else block, pass x to a function that does not accept a value of any type.

The function from (3) may look as follows:

from typing import NoReturn

def assert_exhaustiveness(x: NoReturn) -> NoReturn:
  """Provide an assertion at type-check time that this function is never called."""
  raise AssertionError(f"Invalid value: {x!r}")

No existing value can have the type NoReturn, so whatever possible variant is left in the else block causes the mypy check to fail. Note that what causes this failure is the signature of the function; the body could be a simple pass as well. However, raising an exception in the body is helpful to get an error at runtime, for example, if mypy was not run or if some value was manually cast into an incorrect type.
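
As a small sketch of that runtime safety net, the snippet below uses typing.cast() to smuggle a value past the type checker. The two-variant UserInput and the function describe() are simplified, hypothetical stand-ins for the examples in this post.

```python
from dataclasses import dataclass
from typing import NoReturn, Union, cast


def assert_exhaustiveness(x: NoReturn) -> NoReturn:
    """Provide an assertion at type-check time that this function is never called."""
    raise AssertionError(f"Invalid value: {x!r}")


@dataclass
class Quit:
    pass


@dataclass
class KeyPress:
    c: str


UserInput = Union[Quit, KeyPress]


def describe(evt: UserInput) -> str:
    if isinstance(evt, Quit):
        return "done"
    elif isinstance(evt, KeyPress):
        return f"user pressed key {evt.c}"
    else:
        assert_exhaustiveness(evt)


# cast() does nothing at runtime; it only tells the type checker
# to treat the integer as a UserInput.
bogus = cast(UserInput, 42)

try:
    describe(bogus)
except AssertionError as exc:
    message = str(exc)
    print(message)
```

With a body of pass, the same call would return None without any complaint, which is why raising is the safer choice.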

If we rewrite the code at the top of this section into

if isinstance(evt, Quit):
    print("done")
elif isinstance(evt, KeyPress):
    print(f"user pressed key {evt.c}")
elif isinstance(evt, MouseClick):
    print(f"user clicked at {evt.x},{evt.y}")
else:
    assert_exhaustiveness(evt)

where UserInput = Union[Quit, KeyPress, MouseClick], then it will pass a mypy check. However, if we add the ScreenTouch class to the Union as well, then the same code will make mypy fail with:

error: Argument 1 to "assert_exhaustiveness" has incompatible type "ScreenTouch"; expected "NoReturn"

mypy inferred that the only possible type that is left for evt in the else block is ScreenTouch, but we cannot pass ScreenTouch to a function that accepts only NoReturn.

Similarly, the following code will make mypy fail:

import enum
from typing import Literal

class Color(enum.Enum):
    RED = enum.auto()
    GREEN = enum.auto()
    BLUE = enum.auto()

def enum_check(x: Color) -> None:
    if x is Color.RED:
        print("red")
    elif x is Color.GREEN:
        print("green")
    else:
        assert_exhaustiveness(x)

def literal_check(x: Literal["yes", "no", "maybe"]) -> None:
    if x == "yes":
        print("yes")
    elif x == "no":
        print("no")
    else:
        assert_exhaustiveness(x)

The error messages are

error: Argument 1 to "assert_exhaustiveness" has incompatible type "Literal[Color.BLUE]"; expected "NoReturn"

and

error: Argument 1 to "assert_exhaustiveness" has incompatible type "Literal['maybe']"; expected "NoReturn"

respectively. Internally mypy translates the Enum into a Literal[Color.RED, Color.GREEN, Color.BLUE]. However, note that unlike a check for a literal value, the check for an Enum variant inside an if statement must be done using is; it does not work with ==.
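
One way to resolve both errors is to handle the missing variant explicitly. The sketch below does that; the function names color_name() and answer_name() are hypothetical, and the functions return strings instead of printing so that the results can be checked. Since every variant is ruled out before the else block, I would expect mypy to accept both functions.

```python
import enum
from typing import Literal, NoReturn


def assert_exhaustiveness(x: NoReturn) -> NoReturn:
    """Provide an assertion at type-check time that this function is never called."""
    raise AssertionError(f"Invalid value: {x!r}")


class Color(enum.Enum):
    RED = enum.auto()
    GREEN = enum.auto()
    BLUE = enum.auto()


def color_name(x: Color) -> str:
    # Enum members are compared with `is`, as discussed above.
    if x is Color.RED:
        return "red"
    elif x is Color.GREEN:
        return "green"
    elif x is Color.BLUE:
        return "blue"
    else:
        assert_exhaustiveness(x)


def answer_name(x: Literal["yes", "no", "maybe"]) -> str:
    # Literal strings are compared with `==`.
    if x == "yes":
        return "yes"
    elif x == "no":
        return "no"
    elif x == "maybe":
        return "maybe"
    else:
        assert_exhaustiveness(x)


print(color_name(Color.BLUE), answer_name("maybe"))
```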

Digression: Use in Functional-Style Python

In the previous section I started from the example code

if isinstance(evt, Quit):
    print("done")
elif isinstance(evt, KeyPress):
    print(f"user pressed key {evt.c}")
elif isinstance(evt, MouseClick):
    print(f"user clicked at {evt.x},{evt.y}")

and argued that this should be extended by an else block with a special function call to guarantee the exhaustiveness. However, if you find that additional block too verbose or feel that we shouldn’t add dead code to get these checks, there is another way we can make use of mypy’s type narrowing, without the else block.

If the code above is made the final part of a stand-alone function, and each of the if/elif blocks returns a value rather than just calling some other code, then the exhaustiveness can also be guaranteed by mypy’s check of whether all code paths return the correct type. Consider the following function:

def event_desc(evt: UserInput) -> str:
    if isinstance(evt, Quit):
        return "done"
    if isinstance(evt, KeyPress):
        return f"user pressed key {evt.c}"
    if isinstance(evt, MouseClick):
        return f"user clicked at {evt.x},{evt.y}"

This function passes the mypy check if UserInput is Union[Quit, KeyPress, MouseClick], but if ScreenTouch is added to the Union then it fails the mypy check with:

error: Missing return statement

This functional style avoids the else block with the assert_exhaustiveness() call and it may be suitable for cases where each if/elif block is run not only for its side effects. However, the error message is less descriptive, and there is a risk of accidentally losing the exhaustiveness checks when the code is refactored to run inside another block, rather than as a function of its own.
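
To make that refactoring risk concrete, here is a sketch in which the same branches have been moved into a loop. The function describe_all() is hypothetical. No code path is missing a return statement any more, so there is nothing left for the "Missing return statement" check to catch, and the unhandled variant is skipped silently at runtime.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Quit:
    pass


@dataclass
class KeyPress:
    c: str


@dataclass
class MouseClick:
    x: int
    y: int


@dataclass
class ScreenTouch:
    x: int
    y: int
    pressure: float


UserInput = Union[Quit, KeyPress, MouseClick, ScreenTouch]


def describe_all(events: List[UserInput]) -> List[str]:
    descriptions: List[str] = []
    for evt in events:
        if isinstance(evt, Quit):
            descriptions.append("done")
        elif isinstance(evt, KeyPress):
            descriptions.append(f"user pressed key {evt.c}")
        elif isinstance(evt, MouseClick):
            descriptions.append(f"user clicked at {evt.x},{evt.y}")
        # ScreenTouch is not handled, and the function still returns a list.
    return descriptions


print(describe_all([KeyPress("a"), ScreenTouch(1, 2, 0.5), Quit()]))
```

In a situation like this, the else block with assert_exhaustiveness() is the more robust choice.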

Benefits

Adding an exhaustiveness check as shown above can provide a lot of safety in situations where sum types are likely to be extended by additional variants in the future. This may not be the case for a Color enumeration, but certainly for, say, a UserEvent, AuthMethod, or LogLevel union. Having mypy check that all variants are covered, before the code is ever executed, saves some of the effort otherwise spent on unit tests for that purpose.

As a corollary, it becomes much more attractive to use enumerations or Literal types to parameterize behavior, rather than, say, plain strings. As an example, consider the mode parameter of the open() function: This parameter is annotated as a Literal, so on the calling side a user can have mypy check if the function is used correctly, but also on the implementing side there is some additional safety that all possible parameter values have been covered. This is a significant improvement over using a plain string.
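
A minimal sketch of the implementing side could look as follows. The Mode alias and the function describe_mode() are hypothetical and much smaller than the real set of modes accepted by open(); they only illustrate the pattern.

```python
from typing import Literal, NoReturn

# Hypothetical, simplified set of modes for illustration only.
Mode = Literal["r", "w", "a"]


def assert_exhaustiveness(x: NoReturn) -> NoReturn:
    """Provide an assertion at type-check time that this function is never called."""
    raise AssertionError(f"Invalid value: {x!r}")


def describe_mode(mode: Mode) -> str:
    if mode == "r":
        return "read"
    elif mode == "w":
        return "truncate and write"
    elif mode == "a":
        return "append"
    else:
        # If a new mode is added to the Literal, mypy should flag this call.
        assert_exhaustiveness(mode)


print(describe_mode("a"))
```

On the calling side, a call such as describe_mode("x") would be rejected by the type checker, which a plain str annotation could not provide.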

Finally, in some cases it can be helpful for IDE features like “Find Usages” (or grep searches) if each variant is explicitly named in an if/elif statement rather than covered in an else block.

Limitations

The technique described above works purely on type narrowing. There is no other logic, such as narrowing down the possible value range for variables. At the time of writing, even simple additional conditions break the exhaustiveness checker. Consider the following code snippet:

def is_check_exhaustive_with_condition(x: Color, i: int) -> None:
    if x is Color.RED:
        print("red")
    elif x is Color.GREEN:
        print("green")
    elif x is Color.BLUE and i < 0:
        print("blue/neg")
    elif x is Color.BLUE and i >= 0:
        print("blue/nonneg")
    else:
        assert_exhaustiveness(x)

It is easy to see that the else block is unreachable, but mypy still fails with

error: Argument 1 to "assert_exhaustiveness" has incompatible type "Literal[Color.BLUE]"; expected "NoReturn"

However, if you manage to express your condition as a type narrowing problem (not saying that this is always a good idea) then you may be able to work around some limitations. For example, the following code snippet is equivalent to the one above, but passes a mypy check:

def is_check_exhaustive_with_conditions(x: Color, i: int) -> None:
    i_type: Literal["neg", "nonneg"] = "neg" if i < 0 else "nonneg"
    if x is Color.RED:
        print("red")
    elif x is Color.GREEN:
        print("green")
    elif x is Color.BLUE and i_type == "neg":
        print("blue/neg")
    elif x is Color.BLUE and i_type == "nonneg":
        print("blue/nonneg")
    else:
        assert_exhaustiveness(x)
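
Another option, sketched here under the assumption that the extra condition only matters for a single variant, is to keep the variant check and the value check apart: narrow on x alone in the if/elif chain and test i inside the Color.BLUE branch. The function name describe() is hypothetical, and it returns the text instead of printing it. Because only x takes part in the narrowing, I would expect mypy to see the else block as unreachable.

```python
import enum
from typing import NoReturn


def assert_exhaustiveness(x: NoReturn) -> NoReturn:
    """Provide an assertion at type-check time that this function is never called."""
    raise AssertionError(f"Invalid value: {x!r}")


class Color(enum.Enum):
    RED = enum.auto()
    GREEN = enum.auto()
    BLUE = enum.auto()


def describe(x: Color, i: int) -> str:
    if x is Color.RED:
        return "red"
    elif x is Color.GREEN:
        return "green"
    elif x is Color.BLUE:
        # The condition on i lives inside the branch, so it does not
        # interfere with the narrowing of x.
        return "blue/neg" if i < 0 else "blue/nonneg"
    else:
        assert_exhaustiveness(x)


print(describe(Color.BLUE, -1), describe(Color.BLUE, 0))
```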

Conclusion

In this blog post I have presented the various ways to define sum types in recent Python versions, and shown how you can make use of mypy features to achieve exhaustive matching similar to what many other languages provide. I am sure that in the eyes of many programmers, adding a dummy else block all over the code base that is never reached at runtime is not considered “elegant”. However, I hope I could demonstrate that it can improve safety and reduce the need for unit tests where sum types are used, in turn making it more attractive to use such types in Python code.

Discuss this post on Hacker News: https://news.ycombinator.com/item?id=29897462
