Why do we need itertools?

The idea behind itertools is to deal with large amounts of data (typically sequence data sets) in a memory-efficient way. According to the docs [1]:

This module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python.

The module standardizes a core set of fast, memory efficient tools that are useful by themselves or in combination. Together, they form an “iterator algebra” making it possible to construct specialized tools succinctly and efficiently in pure Python.

All functions included in the itertools module construct and return iterators. Iterators implement the iterator protocol, which means that large data sets can be consumed “lazily”. In other words, the whole data set does not need to be in memory at once during processing. Instead, each element is produced and processed one at a time. This avoids problems that come with holding large data sets in memory (such as swapping) and can improve performance. If you want to know more about iterators, I highly recommend two articles by Trey Hunner (see [2] and [3]).
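To make “lazy” a bit more concrete, here is a small sketch. The generator numbers() is a hypothetical data source that records every value it actually produces, and islice() (covered later in this article) takes only the first three items of the pipeline.

```python
from itertools import count, islice

produced = []  # every value the hypothetical source has actually generated


def numbers():
    """Hypothetical data source that logs each value before yielding it."""
    for n in count():
        produced.append(n)
        yield n


# Building the pipeline does not run the source yet.
squares = (n * n for n in numbers())
print(produced)  # []

# Values are pulled through one at a time, and only as many as requested.
first_three = list(islice(squares, 3))
print(first_three)  # [0, 1, 4]
print(produced)     # [0, 1, 2]
```

Even though count() never ends, only three values are ever generated.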

The goal of this article is to show you the functions available in the itertools module and to present possible use cases for them. The code snippets used in this article can be found on GitHub.

What does the module cover?

The functions in the itertools module can be put into three groups:

  1. Infinite iterators
  2. Iterators terminating on the shortest input sequence
  3. Combinatoric iterators

Infinite iterators

Infinite iterators produce streams of infinite length (as the name suggests). When you consume them, make sure something stops the iteration, such as a loop with a break condition or a function like zip() or islice() that stops early; otherwise the loop never ends. The itertools module provides three of them: count(), cycle(), and repeat().
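A few common ways to truncate an infinite iterator are sketched below, using the three functions named above together with islice() and the built-ins zip() and map(). The data is made up for illustration.

```python
from itertools import count, cycle, islice, repeat

# 1. islice() stops after a fixed number of items.
first_evens = list(islice(count(0, 2), 5))
print(first_evens)  # [0, 2, 4, 6, 8]

# 2. zip() stops as soon as its shortest input (here a finite list) is exhausted.
labelled = list(zip(["a", "b", "c"], cycle([True, False])))
print(labelled)  # [('a', True), ('b', False), ('c', True)]

# 3. An explicit break condition inside a loop.
powers_of_two = []
for value in map(pow, repeat(2), count()):  # 2**0, 2**1, 2**2, ...
    if value > 100:
        break
    powers_of_two.append(value)
print(powers_of_two)  # [1, 2, 4, 8, 16, 32, 64]
```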


count() returns an iterator that counts up from a start value (0 by default). Optionally, you can provide the step size as a second argument.

>>> from itertools import count
>>> c = count()
>>> for i in range(5):
...     print(next(c))
...
0
1
2
3
4
>>> from itertools import count
>>> c = count(10, 2)
>>> for i in range(5):
...     print(next(c))
...
10
12
14
16
18

The count() function really shines in combination with other functions. Let’s assume you have a list of names and you want to index them. This may look like this:

# itertools_count.py

from itertools import count

names = ["Alice", "Bob", "Larry", "Margaret"]
names_with_index = list(zip(count(1), names))
print(names_with_index)

Note: zip() is a built-in function returning an iterator of tuples, where the i-th tuple contains the i-th element from each of the argument sequences or iterables: zip([1, 2], [3, 4]) --> (1, 3) (2, 4)

$ python itertools_count.py
[(1, 'Alice'), (2, 'Bob'), (3, 'Larry'), (4, 'Margaret')]
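For this particular task, the built-in enumerate() with its start argument produces the same pairs. count() becomes more useful when you need a custom step, which enumerate() does not offer. A small comparison, reusing the names list from the example:

```python
from itertools import count

names = ["Alice", "Bob", "Larry", "Margaret"]

with_count = list(zip(count(1), names))
with_enumerate = list(enumerate(names, start=1))
print(with_count == with_enumerate)  # True

# count() also accepts a step, e.g. numbering in steps of 10.
with_step = list(zip(count(10, 10), names))
print(with_step)  # [(10, 'Alice'), (20, 'Bob'), (30, 'Larry'), (40, 'Margaret')]
```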

cycle() returns the elements of an iterable one by one and saves a copy of each. Once the iterable is exhausted, it keeps returning the saved elements indefinitely. To understand the function better, consider the following use case: you are teaching a class and, for a group assignment, you want to divide the students into three teams. The following code snippet shows a possible implementation using cycle().

# itertools_cycle.py

from itertools import cycle

names = ["Alice", "Bob", "Chris", "Larry", "Margaret", "Naomi", "Sarah"]
groups = cycle([1, 2, 3])
names_with_groups = list(zip(names, groups))

print(names_with_groups)

$ python itertools_cycle.py
[('Alice', 1), ('Bob', 2), ('Chris', 3), ('Larry', 1), ('Margaret', 2), ('Naomi', 3), ('Sarah', 1)]
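Because cycle() saves a copy of each element on its first pass, it also works with one-shot iterators such as generators, which could not be iterated a second time on their own. A short sketch with a made-up generator:

```python
from itertools import cycle, islice

# A generator can normally be consumed only once.
colors = (c for c in ["red", "green"])

# cycle() remembers what it has seen, so it can replay the values forever;
# islice() keeps this example finite.
repeated = list(islice(cycle(colors), 5))
print(repeated)  # ['red', 'green', 'red', 'green', 'red']
```

The flip side is memory: cycle() keeps every element of its input, so cycling over a very large iterable stores all of it.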

repeat() receives an object as its argument and returns it over and over again. Optionally, you can specify the number of repetitions as a second argument; otherwise, it repeats forever. repeat() is commonly used together with the built-in map() and zip() functions. The following example is from the itertools documentation [1]. It computes the squares of the numbers 0 to 9.

>>> from itertools import repeat
>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
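Two more sketches of repeat(), one with the optional count and one paired with zip(); the names and the "student" label are made up for illustration:

```python
from itertools import repeat

# With a second argument, repeat() stops after that many items.
separators = list(repeat("-", 3))
print(separators)  # ['-', '-', '-']

# Without it, repeat() is infinite, so the finite list decides the length.
roles = list(zip(["Alice", "Bob"], repeat("student")))
print(roles)  # [('Alice', 'student'), ('Bob', 'student')]
```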

Iterators terminating on the shortest input sequence

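As the name suggests, the functions in this group stop once their input is exhausted (zip_longest(), covered at the end of this section, is the exception: it keeps going until the longest input is exhausted). Unlike infinite iterators, they produce finite output as long as their inputs are finite. This group is by far the largest, with the following 12 entries (newer Python versions add pairwise() in 3.10 and batched() in 3.12):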

  • accumulate()
  • chain()
  • chain.from_iterable()
  • compress()
  • dropwhile()
  • filterfalse()
  • groupby()
  • islice()
  • starmap()
  • takewhile()
  • tee()
  • zip_longest()

accumulate() returns the running results of a binary function applied to the elements of an iterable. By default, it uses addition (operator.add()), which produces running sums.

# itertools_accumulate.py

from itertools import accumulate
from operator import mul

numbers = [1, 2, 3, 4, 5]
result1 = accumulate(numbers)
result2 = accumulate(numbers, mul)
result3 = accumulate(numbers, initial=100)

print(f"Result 1: {list(result1)}")
print(f"Result 2: {list(result2)}")
print(f"Result 3: {list(result3)}")

In this example, numbers contains the numbers 1 to 5. The iterator assigned to result1 returns running sums: each value is the previous sum plus the next element, until the input is exhausted. The iterator assigned to result2 uses operator.mul() instead of operator.add() and therefore returns running products. The iterator assigned to result3 has the initial parameter set (available since Python 3.8). The initial value is emitted first and serves as the starting point of all following sums, so the output has one element more than the input. To print the results, we need to consume the iterators by turning them into lists with the built-in list() function.

$ python itertools_accumulate.py
Result 1: [1, 3, 6, 10, 15]
Result 2: [1, 2, 6, 24, 120]
Result 3: [100, 101, 103, 106, 110, 115]
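Any two-argument callable can be passed instead of operator.mul(). The sketch below uses the built-ins max() and min() on made-up sensor readings to get a running maximum and minimum:

```python
from itertools import accumulate

readings = [3, 1, 4, 1, 5, 9, 2, 6]

# Each output value is max()/min() of the previous result and the next element.
running_max = list(accumulate(readings, max))
running_min = list(accumulate(readings, min))

print(running_max)  # [3, 3, 4, 4, 5, 9, 9, 9]
print(running_min)  # [3, 1, 1, 1, 1, 1, 1, 1]
```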

Sometimes, it is necessary to consume multiple iterators and/or iterables sequentially. Instead of using multiple for-loops, you can utilize chain().

# itertools_chain.py

from itertools import chain

class1 = ["Alice", "Bob", "Chris"]
class2 = ["Larry", "Margaret", "Naomi", "Sarah"]
all_people = list(chain(class1, class2))

print(f"All people: {all_people}")

$ python itertools_chain.py
All people: ['Alice', 'Bob', 'Chris', 'Larry', 'Margaret', 'Naomi', 'Sarah']

If you want to pass an iterable of iterators/iterables, you can use the alternate constructor chain.from_iterable(). The example above would change as follows:

all_people = list(chain.from_iterable([class1, class2]))
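Because chain.from_iterable() takes a single iterable of iterables, that outer iterable can itself be lazy, for example a generator expression that creates the inner iterables on demand. A small sketch with made-up ranges:

```python
from itertools import chain

# The outer generator produces range(1), range(2), range(3) one at a time.
flat = list(chain.from_iterable(range(n) for n in range(1, 4)))
print(flat)  # [0, 0, 1, 0, 1, 2]

# chain(*...) gives the same result here, but unpacking with * builds the whole
# outer sequence up front, which rules out infinite or very large outer iterables.
flat_unpacked = list(chain(*(range(n) for n in range(1, 4))))
print(flat_unpacked == flat)  # True
```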

The next function we will cover is compress(). It receives two arguments: data, the iterable you want to filter, and selectors, an iterable whose elements decide, one by one, whether the corresponding element of data is kept (truthy selector) or dropped (falsy selector).

# itertools_compress.py

from itertools import compress


def name_selection(names):
    name_selectors = []

    for name in names:
        if name.startswith("A"):
            name_selectors.append(1)
        else:
            name_selectors.append(0)

    return name_selectors


names = ["Albert", "Alexandra", "Miriam", "Sascha"]
filtered_names = list(compress(names, name_selection(names)))

print(f"Filtered names: {filtered_names}")

In this example, we define a custom selection function that receives a list of names and checks whether each name starts with an A. If so, it appends 1 (which is truthy), otherwise 0 (which is falsy). As a result, compress() returns only the names starting with an A.

$ python itertools_compress.py
Filtered names: ['Albert', 'Alexandra']
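The selectors do not have to be 1s and 0s in a list: booleans from a generator expression work just as well. Note also that compress() stops as soon as either input is exhausted. A brief sketch:

```python
from itertools import compress

names = ["Albert", "Alexandra", "Miriam", "Sascha"]

# Booleans work as selectors, and a generator expression avoids building a list.
a_names = list(compress(names, (name.startswith("A") for name in names)))
print(a_names)  # ['Albert', 'Alexandra']

# Only three selectors, so only the first three letters are considered.
short = list(compress("ABCDEF", [1, 0, 1]))
print(short)  # ['A', 'C']
```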

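dropwhile() drops elements as long as the specified condition (predicate) is true. As soon as the condition is false for the first time, that element and all remaining elements are returned, without checking the condition again. That is why the trailing 4 and 1 in the following example are kept even though they are smaller than 5.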

>>> from itertools import dropwhile
>>> list(dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]))
[6, 4, 1]

takewhile() is the counterpart of dropwhile(): it returns elements as long as the predicate is true and stops at the first element for which it is false.

>>> from itertools import takewhile
>>> list(takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]))
[1, 4]
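Used together, the two functions split a sequence at the first element that fails a condition. The sketch below separates hypothetical comment lines at the top of a file from the rest. The later comment line ends up in body, because dropwhile() never checks the predicate again.

```python
from itertools import dropwhile, takewhile

lines = ["# title", "# author", "data 1", "# inline note", "data 2"]


def is_comment(line):
    return line.startswith("#")


header = list(takewhile(is_comment, lines))
body = list(dropwhile(is_comment, lines))

print(header)  # ['# title', '# author']
print(body)    # ['data 1', '# inline note', 'data 2']
```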

The next function we are going to investigate is filterfalse(). It is the complement of the built-in filter() function: instead of returning the elements for which the specified function returns true, it returns only the elements for which the function returns false.

# itertools_filterfalse.py

from itertools import filterfalse


def is_negative(number):
    return number < 0


numbers = [-1, 0, 4, 1, -3]
# 0 is neither negative nor positive, so the result holds the non-negative numbers.
non_negative_numbers = list(filterfalse(is_negative, numbers))

print(non_negative_numbers)

$ python itertools_filterfalse.py
[0, 4, 1]
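filterfalse() and filter() are often used as a pair to split data into two groups. Both also accept None instead of a function, in which case the truth value of each element is used directly. A sketch with made-up values:

```python
from itertools import filterfalse

values = [0, 1, "", "text", None, [], [1]]

# With None as the predicate, filterfalse() keeps the falsy elements,
# while the built-in filter() keeps the truthy ones.
falsy = list(filterfalse(None, values))
truthy = list(filter(None, values))

print(falsy)   # [0, '', None, []]
print(truthy)  # [1, 'text', [1]]
```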

The next function we will look at is groupby(). Let’s assume you have a list of data points, which consist of a group and a value. To make things easy, a data point is a tuple of the form (group, value), where both are simply integers. Then we can group them as follows:

# itertools_groupby.py

from itertools import groupby
from operator import itemgetter

data = [
    (0, 0),
    (0, 1),
    (1, 4),
    (0, 9),
    (1, 2),
    (2, 5),
    (1, 6),
]

for k, v in groupby(data, itemgetter(0)):
    print(k, list(v))

groupby() returns an iterator yielding pairs of a key and the group of consecutive elements sharing that key. Each group is itself an iterator, which is why we turn it into a list with list(v). We pass the key function as the second argument: operator.itemgetter(0) returns the first element of each data point, which serves as its group. However, if we run the code snippet at hand, we get the following:

$ python itertools_groupby.py
0 [(0, 0), (0, 1)]
1 [(1, 4)]
0 [(0, 9)]
1 [(1, 2)]
2 [(2, 5)]
1 [(1, 6)]

You might have expected only three groups with the keys 0, 1, and 2 and their corresponding values. However, groupby() only groups consecutive elements: it starts a new group every time the key of the current element differs from the key of the previous one. Since the data points with key 0 (or 1) are not next to each other, they end up in several groups. To get three groups, we need to sort the data points before grouping them, typically by the same key we pass to groupby(). Here, sorting the tuples directly is enough, because tuples are compared by their first element first. Therefore, let’s add the following line before the for-loop:

data.sort()

Running the script again prints the desired result.

$ python itertools_groupby.py
0 [(0, 0), (0, 1), (0, 9)]
1 [(1, 2), (1, 4), (1, 6)]
2 [(2, 5)]
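The sketch below groups some made-up words by their first letter, sorting with the same key function that groupby() uses. It also shows a common pitfall: all groups share the underlying iterator, so once groupby() moves on, an earlier group is no longer available. That is why each group should be turned into a list right away.

```python
from itertools import groupby
from operator import itemgetter

words = ["banana", "apple", "cherry", "avocado", "blueberry"]
first_letter = itemgetter(0)

# Sort by the grouping key first, then consume each group immediately.
by_letter = {
    letter: list(group)
    for letter, group in groupby(sorted(words, key=first_letter), key=first_letter)
}
print(by_letter)
# {'a': ['apple', 'avocado'], 'b': ['banana', 'blueberry'], 'c': ['cherry']}

# Pitfall: list(groupby(...)) advances past every group before they are read,
# so the stored group iterators come back empty.
stale = [(key, list(group)) for key, group in list(groupby("aabbb"))]
print(stale)  # [('a', []), ('b', [])]
```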

islice() returns selected elements from an iterable. You can think of it as the iterator counterpart of the slice operator [] for lists and tuples. It can be called as islice(iterable, stop) or islice(iterable, start, stop[, step]), where start, stop, and step work like their slicing counterparts. Unlike regular slicing, negative values are not supported.

# itertools_islice.py

from itertools import islice

list1 = list(islice(range(50), 2))
list2 = list(islice(range(50), 40, 44))
list3 = list(islice(range(50), 5, 45, 10))

print(f"islice with stop parameter only: {list1}")
print(f"islice with start and stop: {list2}")
print(f"islice with start, stop, and step: {list3}")

Note: Again, we used the built-in list() function to consume the whole iterator and turn it into a list.

$ python itertools_islice.py
islice with stop parameter only: [0, 1]
islice with start and stop: [40, 41, 42, 43]
islice with start, stop, and step: [5, 15, 25, 35]
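Unlike slicing a list, islice() advances the iterator it reads from, so repeated calls on the same iterator continue where the previous call stopped. The hypothetical helper chunked() below builds on that to split an iterable into batches; it is a sketch that needs Python 3.8+ for the := operator.

```python
from itertools import islice

numbers = iter(range(10))
first_batch = list(islice(numbers, 3))
second_batch = list(islice(numbers, 3))
print(first_batch, second_batch)  # [0, 1, 2] [3, 4, 5]


def chunked(iterable, size):
    """Hypothetical helper: yield lists of up to `size` items."""
    it = iter(iterable)
    # An empty list is falsy, which ends the loop once the input is exhausted.
    while batch := list(islice(it, size)):
        yield batch


batches = list(chunked(range(7), 3))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```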

starmap() returns an iterator that calls a given function with arguments taken from a given iterable. It is similar to the built-in map() function. However, while map() takes one argument from each of several iterables, starmap() takes a single iterable whose elements are already argument tuples and unpacks each of them into the function call, as if written with the * syntax: function(*item).

# itertools_starmap.py

from itertools import starmap


def pow_with_input(base, exponent):
    return base, exponent, pow(base, exponent)


values = [(4, 9), (1, 6), (0, 5), (3, 8), (2, 7)]

for i in starmap(pow_with_input, values):
    print("pow({}, {}) = {}".format(*i))

In the example at hand, we first define a simple function that computes the power for a given base and exponent. Instead of only returning the result, it returns a tuple of the form (base, exponent, result). We consume the iterator returned by starmap() with a for-loop and print each element in a user-friendly way.

$ python itertools_starmap.py
pow(4, 9) = 262144
pow(1, 6) = 1
pow(0, 5) = 0
pow(3, 8) = 6561
pow(2, 7) = 128
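To see the difference from map() side by side, the sketch below computes the same powers both ways with made-up (base, exponent) pairs. With map(), the pairs first have to be split into separate iterables of bases and exponents.

```python
from itertools import starmap

pairs = [(2, 5), (3, 2), (10, 3)]

# starmap() unpacks each tuple: pow(2, 5), pow(3, 2), pow(10, 3).
with_starmap = list(starmap(pow, pairs))

# map() takes one argument per iterable instead.
bases, exponents = zip(*pairs)
with_map = list(map(pow, bases, exponents))

print(with_starmap)  # [32, 9, 1000]
print(with_map)      # [32, 9, 1000]
```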

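tee() takes an iterable and returns independent iterators based on it. By default, it returns two iterators, but you can specify the number as the second argument. Note that once tee() has been called, the original iterable should no longer be used. Otherwise, it may be advanced without the tee() iterators being informed, and they will silently miss those elements. Also keep in mind that tee() has to store every element that one of its iterators has already consumed but another has not, which may require a lot of memory if they drift far apart.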

# itertools_tee.py

from itertools import islice
from itertools import tee

s = islice(range(100), 3)
s1, s2 = tee(s)

print(f"First list: {list(s1)}")
print(f"Second list: {list(s2)}")

$ python itertools_tee.py
First list: [0, 1, 2]
Second list: [0, 1, 2]

Note: Iterators returned by tee() are not threadsafe. According to the docs: “A RuntimeError may be raised when using simultaneously iterators returned by the same tee() call, even if the original iterable is threadsafe.”
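A classic use of tee() is looking one element ahead. The sketch below follows the pairwise recipe from the itertools documentation (the name pairwise_sketch is made up; on Python 3.10 and newer, itertools.pairwise() does the same directly). Since zip() advances both copies in turn, they never drift more than one element apart, so tee() only buffers a single item here.

```python
from itertools import tee


def pairwise_sketch(iterable):
    """Yield overlapping pairs (s0, s1), (s1, s2), ... using tee()."""
    a, b = tee(iterable)
    next(b, None)  # advance the second copy by one element
    return zip(a, b)


pairs = list(pairwise_sketch([1, 2, 3, 4]))
print(pairs)  # [(1, 2), (2, 3), (3, 4)]
```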


Last but not least, we have a look at zip_longest(). If you use the built-in zip() function to combine two iterables, it stops as soon as the shorter one is exhausted. If you want to continue until the longer iterable is exhausted, you can use zip_longest(), which fills in missing values with the specified fillvalue (None by default).

# itertools_zip_longest.py

from itertools import zip_longest

a = [1, 2, 3]
b = ["One", "Two"]

result1 = list(zip(a, b))
result2 = list(zip_longest(a, b))

print(result1)
print(result2)

$ python itertools_zip_longest.py
[(1, 'One'), (2, 'Two')]
[(1, 'One'), (2, 'Two'), (3, None)]
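A short sketch of the fillvalue parameter, plus zip_longest() with three made-up inputs of different lengths:

```python
from itertools import zip_longest

a = [1, 2, 3]
b = ["One", "Two"]

# A custom placeholder instead of None.
padded = list(zip_longest(a, b, fillvalue="-"))
print(padded)  # [(1, 'One'), (2, 'Two'), (3, '-')]

# Any number of iterables works; each missing position gets the fill value.
three = list(zip_longest("AB", "xyz", [True]))
print(three)  # [('A', 'x', True), ('B', 'y', None), (None, 'z', None)]
```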

Combinatoric iterators

The group of combinatoric iterators consists of the following four functions:

  • product()
  • permutations()
  • combinations()
  • combinations_with_replacement()

product() computes the Cartesian product of the given iterables, which makes it equivalent to nested for-loops. It is also worth noting that, like nested loops, it cycles like an odometer with the rightmost element advancing on every iteration. This creates a lexicographic ordering: if the input iterables are sorted, the product tuples are emitted in sorted order. I found a nice example on Doug Hellmann’s PyMOTW-3 blog [4], which is shown below:

# itertools_product.py

from itertools import chain
from itertools import product

FACE_CARDS = ("J", "Q", "K", "A")
SUITS = ("H", "D", "C", "S")

DECK = list(
    product(
        chain(range(2, 11), FACE_CARDS),
        SUITS,
    )
)

for card in DECK:
    print("{:>2}{}".format(*card), end=" ")
    if card[1] == SUITS[-1]:
        print()

It computes all cards of a standard deck by combining the numbers 2 to 10 and the ranks in FACE_CARDS with the four suits.

$ python itertools_product.py
 2H  2D  2C  2S
 3H  3D  3C  3S
 4H  4D  4C  4S
 5H  5D  5C  5S
 6H  6D  6C  6S
 7H  7D  7C  7S
 8H  8D  8C  8S
 9H  9D  9C  9S
10H 10D 10C 10S
 JH  JD  JC  JS
 QH  QD  QC  QS
 KH  KD  KC  KS
 AH  AD  AC  AS
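To check the nested-loop equivalence, the sketch below builds all 3-bit combinations once with product() and its repeat parameter and once with nested loops in a comprehension:

```python
from itertools import product

bits = [0, 1]

# repeat=3 is shorthand for product(bits, bits, bits).
with_product = list(product(bits, repeat=3))

# The same tuples, in the same order, from nested loops (c advances fastest).
with_loops = [(a, b, c) for a in bits for b in bits for c in bits]

print(with_product == with_loops)  # True
print(len(with_product))           # 8
```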

The permutations() function generates all possible permutations for a given length r (second argument). If r is not specified, the length of each permutation is equal to the length of the iterable specified as the first argument.

# itertools_permutations.py

from itertools import permutations

l = [1, 2, 3]
result1 = list(permutations(l))
result2 = list(permutations(l, 2))

print(result1)
print(result2)

$ python itertools_permutations.py
[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]
[(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
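The number of r-length permutations of n elements is n! / (n - r)!, which math.perm() (Python 3.8+) computes directly. Like combinations(), permutations() distinguishes elements by position, so an input with repeated values yields repeated tuples. A sketch with made-up inputs:

```python
import math
from itertools import permutations

# 5 elements taken 2 at a time: 5! / 3! = 20.
perm_count = len(list(permutations(range(5), 2)))
print(perm_count, math.perm(5, 2))  # 20 20

# "aab" has two equal letters, so only half of the 6 tuples are distinct.
all_perms = list(permutations("aab"))
unique_perms = set(all_perms)
print(len(all_perms), len(unique_perms))  # 6 3
```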

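In contrast to permutations(), the combinations() function requires the r argument, and the order of elements within a combination does not matter: each combination is emitted only once, with its elements in the same order as in the input. Furthermore, elements are treated as unique based on their position, not on their value. So if the input contains repeated values, such as the two 1s in m below, a combination may contain repeated values as well.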

# itertools_combinations.py

from itertools import combinations

l = [1, 2, 3]
m = [1, 2, 3, 1]
result1 = list(combinations(l, 3))
result2 = list(combinations(l, 2))
result3 = list(combinations(m, 3))

print(result1)
print(result2)
print(result3)

$ python itertools_combinations.py
[(1, 2, 3)]
[(1, 2), (1, 3), (2, 3)]
[(1, 2, 3), (1, 2, 1), (1, 3, 1), (2, 3, 1)]
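The number of r-length combinations of n elements is n! / (r! (n - r)!), available as math.comb() in Python 3.8+. For the list m with a duplicate value, one way to see how many distinct value combinations remain is to normalize each tuple by sorting it, as in this sketch:

```python
import math
from itertools import combinations

# 6 elements taken 2 at a time: 6! / (2! * 4!) = 15.
comb_count = len(list(combinations(range(6), 2)))
print(comb_count, math.comb(6, 2))  # 15 15

# (1, 2, 3) and (2, 3, 1) contain the same values, so they collapse into one.
m = [1, 2, 3, 1]
distinct = sorted({tuple(sorted(c)) for c in combinations(m, 3)})
print(distinct)  # [(1, 1, 2), (1, 1, 3), (1, 2, 3)]
```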

Unlike combinations(), the combinations_with_replacement() function allows individual elements to be repeated within a combination.

# itertools_combinations_with_replacement.py

from itertools import combinations_with_replacement

l = [1, 2, 3]
result1 = list(combinations_with_replacement(l, 3))
result2 = list(combinations_with_replacement(l, 2))

print(result1)
print(result2)

$ python itertools_combinations_with_replacement.py
[(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 2, 2), (1, 2, 3), (1, 3, 3), (2, 2, 2), (2, 2, 3), (2, 3, 3), (3, 3, 3)]
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
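With replacement, the number of r-length combinations of n elements is C(n + r - 1, r). The sketch below checks this against the function for a few small, arbitrarily chosen values of n and r, including the two cases from the example (10 and 6 tuples):

```python
import math
from itertools import combinations_with_replacement

results = []
for n, r in [(3, 3), (3, 2), (4, 2)]:
    generated = len(list(combinations_with_replacement(range(n), r)))
    expected = math.comb(n + r - 1, r)
    results.append((n, r, generated, expected))
    print(n, r, generated, expected)
# 3 3 10 10
# 3 2 6 6
# 4 2 10 10
```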

More itertools

To extend the itertools tool set, you can install the third-party package more-itertools [5], which provides many additional functions built on top of itertools. The package is available via pip:

$ python -m pip install more-itertools

Now, you can use functions like flatten():

# more_itertools_flatten.py

from more_itertools import flatten

nested_list = [[1, 2], [3, 4]]
flattened_list = list(flatten(nested_list))

print(flattened_list)

$ python more_itertools_flatten.py
[1, 2, 3, 4]
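If flatten() is the only function you need, the recipes section of the itertools documentation defines it as chain.from_iterable(), so the standard library alone gives the same result. It removes exactly one level of nesting, as this sketch with a made-up, more deeply nested list shows:

```python
from itertools import chain

nested = [[1, 2], [3, [4, 5]]]

# Only the outer level is flattened; [4, 5] stays a list.
one_level = list(chain.from_iterable(nested))
print(one_level)  # [1, 2, 3, [4, 5]]
```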

Summary

To sum up, we got to know the functions of the itertools module. We had a look at possible use cases and, where possible, connected them to common real-world scenarios. I hope you enjoyed reading the article. Let me know what you think about it via Twitter and feel free to share it with others. Stay curious and keep coding!

References