Comprehensions and Generator expressions

This chapter will show how to use comprehensions and generator expressions for map, filter and reduce operations. You'll also learn about iterators and the yield statement.

Comprehensions

As mentioned earlier, Python provides the map() and filter() built-in functions. Comprehensions provide a terser and usually faster way to implement the same operations. However, the syntax can take a while to understand and get comfortable with.

The minimal requirement for a comprehension is a mapping expression (which could include a function call) and a loop. Here's an example:

>>> nums = (321, 1, 1, 0, 5.3, 2)

# manual implementation
>>> sqr_nums = []
>>> for n in nums:
...     sqr_nums.append(n * n)
... 
>>> sqr_nums
[103041, 1, 1, 0, 28.09, 4]

# list comprehension
>>> [n * n for n in nums]
[103041, 1, 1, 0, 28.09, 4]

The general form of the above list comprehension is [expr loop]. Compared with the manual implementation, the difference is that append() is performed automatically, which is where most of the performance benefit comes from. Note that a list comprehension is named for its output being a list; the input to the for loop can be any iterable (like the tuple in the above example).
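As a small sketch of the two points above, the mapping expression can be a function call, and the input can be any iterable. The variable names here are illustrative only.

```python
# the mapping expression can be a function call
words = ('apple', 'fig', 'banana')
lengths = [len(w) for w in words]

# the input can be any iterable, such as a string or a range object
upper_chars = [c.upper() for c in 'abc']
cubes = [n ** 3 for n in range(5)]

print(lengths)       # [5, 3, 6]
print(upper_chars)   # ['A', 'B', 'C']
print(cubes)         # [0, 1, 8, 27, 64]
```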

Here's an example with a filtering operation. Instead of the following implementations:

# manual implementation
def remove_dunder(obj):
    names = []
    for n in dir(obj):
        if '__' not in n:
            names.append(n)
    return names

# using 'filter' function
def remove_dunder(obj):
    return list(filter(lambda n: '__' not in n, dir(obj)))

You can use comprehension syntax like this:

>>> def remove_dunder(obj):
...     return [n for n in dir(obj) if '__' not in n]
... 
>>> remove_dunder(dict)
['clear', 'copy', 'fromkeys', 'get', 'items', 'keys', 'pop',
 'popitem', 'setdefault', 'update', 'values']

The general form of the above comprehension is [expr loop condition]. If you can write the manual implementation, it is easy to derive the comprehension version. Put the expression (the argument passed to the append() method) first, and then put the loops and conditions in the same order as in the manual implementation. With practice, you'll be able to read and write the comprehension versions naturally.
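The derivation described above can be sketched with a case that needs both mapping and filtering. The names big_sqr_manual and big_sqr, and the n > 1 condition, are made up for this illustration.

```python
nums = (321, 1, 1, 0, 5.3, 2)

# manual implementation: square only the values greater than 1
big_sqr_manual = []
for n in nums:
    if n > 1:
        big_sqr_manual.append(n * n)

# comprehension: the append() argument first, then the loop, then the condition
big_sqr = [n * n for n in nums if n > 1]

print(big_sqr == big_sqr_manual)   # True
```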

Here's an example with the zip() function:

>>> p = [1, 3, 5]
>>> q = [3, 214, 53]
>>> [i + j for i, j in zip(p, q)]
[4, 217, 58]
>>> [i * j for i, j in zip(p, q)]
[3, 642, 265]

And here's a nested loop example:

>>> names = ['Jo', 'Joe', 'Jon']
>>> pairs = []
>>> for i, n1 in enumerate(names):
...     for n2 in names[i+1:]:
...         pairs.append((n1, n2))
... 
>>> pairs
[('Jo', 'Joe'), ('Jo', 'Jon'), ('Joe', 'Jon')]
# note that the loop order is the same as in the manual implementation
>>> [(n1, n2) for i, n1 in enumerate(names) for n2 in names[i+1:]]
[('Jo', 'Joe'), ('Jo', 'Jon'), ('Joe', 'Jon')]

Similarly, you can build dict and set comprehensions by using {} instead of [] characters. Comprehension syntax inside () characters becomes a generator expression (discussed later in this chapter), so you'll need to use tuple() to build a tuple. You can also pass a generator expression to list(), dict() and set() instead of using the [] and {} forms.

>>> marks = dict(Rahul=68, Ravi=92, Rohit=75, Rajan=85, Ram=80)
>>> {k: v for k, v in marks.items() if v >= 80}
{'Ravi': 92, 'Rajan': 85, 'Ram': 80}

>>> colors = {'teal', 'blue', 'green', 'yellow', 'red', 'orange'}
>>> {c for c in colors if 'o' in c}
{'yellow', 'orange'}

>>> dishes = ('Poha', 'Aloo tikki', 'Baati', 'Khichdi', 'Makki roti')
>>> tuple(d for d in dishes if len(d) < 6)
('Poha', 'Baati')
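The list(), set() and dict() alternatives are not shown in the examples above, so here's a minimal sketch. The variable names and the conditions are chosen only for illustration. Note that dict() needs an iterable of (key, value) pairs.

```python
marks = dict(Rahul=68, Ravi=92, Rohit=75, Rajan=85, Ram=80)

# equivalent to [k for k in marks if marks[k] >= 80]
top_names = list(k for k in marks if marks[k] >= 80)

# equivalent to {v // 10 for v in marks.values()}
tens = set(v // 10 for v in marks.values())

# equivalent to {k: v for k, v in marks.items() if v >= 80}
top_marks = dict((k, v) for k, v in marks.items() if v >= 80)

print(top_names)   # ['Ravi', 'Rajan', 'Ram']
print(top_marks)   # {'Ravi': 92, 'Rajan': 85, 'Ram': 80}
```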

If you are still confused with comprehension syntax, see:

Iterator

Partial quote from docs.python glossary: iterator:

An object representing a stream of data. Repeated calls to the iterator’s __next__() method (or passing it to the built-in function next()) return successive items in the stream. When no more data are available a StopIteration exception is raised instead.

The filter() example in the previous section required further processing, such as passing to the list() function to get the output as a list object. This is because the filter() function returns an object that behaves like an iterator. You can pass iterators anywhere iterables are allowed, such as the for loop. Here's an example:

>>> filter_obj = filter(lambda n: '__' not in n, dir(tuple))
>>> filter_obj
<filter object at 0x7fd910e2de80>
>>> for x in filter_obj:
...     print(x)
... 
count
index

One of the differences between an iterable and an iterator is that you can iterate over an iterable such as a list or a tuple any number of times (quite the tongue twister, if I may say so myself), whereas an iterator can be consumed only once. Also, the next() function can be used on an iterator, but not on iterables that aren't iterators. Once you have exhausted an iterator, calling next() on it will result in a StopIteration exception, while a for loop over it will simply finish without running its body. Iterators are lazy and memory efficient since the results are evaluated only when needed, instead of lying around in a container.

>>> names = filter(lambda n: '__' not in n, dir(tuple))
>>> next(names)
'count'
>>> next(names)
'index'
>>> next(names)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration
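To make the contrast between iterables and iterators concrete, here's a short sketch using a list and an iterator created from it. The variable names are illustrative only.

```python
nums = [10, 20, 30]

# a list can be iterated over any number of times
first_pass = [n for n in nums]
second_pass = [n for n in nums]

# an iterator can be consumed only once
nums_iter = iter(nums)
consumed = list(nums_iter)
leftover = list(nums_iter)   # already exhausted, so nothing is left

# a for loop over an exhausted iterator runs zero times, without an error
loop_count = 0
for n in nums_iter:
    loop_count += 1

# next() works on iterators, but not on a plain iterable such as a list
try:
    next(nums)
except TypeError as e:
    error_msg = str(e)

print(first_pass == second_pass)   # True
print(consumed, leftover)          # [10, 20, 30] []
print(loop_count)                  # 0
```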

You can convert an iterable to an iterator using the iter() built-in function.

>>> nums = [321, 1, 1, 0, 5.3, 2]
>>> iter(nums)
<list_iterator object at 0x7fd90e7f8ee0>

Here's a practical example to get a random item from a list without repetition:

>>> import random 
>>> names = ['Jo', 'Ravi', 'Joe', 'Raj', 'Jon']
>>> random.shuffle(names)
>>> random_name = iter(names)
>>> next(random_name)
'Jon'
>>> next(random_name)
'Ravi'
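With this approach, the iterator eventually runs out of names. One way to handle that, sketched below, is the optional second argument of next(), which is returned instead of raising StopIteration. Using None as the default and drawing 7 times are arbitrary choices for this illustration.

```python
import random

names = ['Jo', 'Ravi', 'Joe', 'Raj', 'Jon']
random.shuffle(names)
random_name = iter(names)

# draw more times than there are names; the default value is returned
# once the iterator is exhausted
draws = [next(random_name, None) for _ in range(7)]

picked = draws[:5]   # every name exactly once, in shuffled order
extra = draws[5:]    # [None, None]
```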

yield

Functions that use the yield statement instead of return to create an iterator are known as generators. Quoting from docs.python: Generators:

Each time next() is called on it, the generator resumes where it left off (it remembers all the data values and which statement was last executed).

Here's a Fibonacci generator:

>>> def fibonacci(n):
...     a, b = 0, 1
...     for _ in range(n):
...         yield a
...         a, b = b, a + b
... 
>>> fibonacci(5)
<generator object fibonacci at 0x7fd90e7b22e0>
>>> list(fibonacci(10))
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
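To see the "resumes where it left off" behavior one step at a time, here's a minimal sketch that calls next() manually. The countdown() generator is a hypothetical example written only for this illustration.

```python
def countdown(n):
    # hypothetical generator, used only for illustration
    while n > 0:
        yield n
        n -= 1   # execution resumes here on the following next() call

c = countdown(3)
first = next(c)    # runs until the first yield
second = next(c)   # resumes, decrements n, and yields again
rest = list(c)     # consumes whatever is left

# a generator is an iterator, so it is exhausted after one pass
again = list(c)

print(first, second, rest, again)   # 3 2 [1] []
```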

For a more detailed discussion and related features, see:

Generator expressions

Comprehension syntax inside () characters creates an iterator, known as a generator expression. A generator expression is more memory efficient than a comprehension, and can be faster as well, whenever you need a single-use iterable. If you use a comprehension, you'll be wasting memory to save the values in a container, only for them to be discarded once they are processed by a reduce operation such as the sum() function in the examples below.

>>> nums = [100, 53, 32, 0, 11, 5, 2]
>>> g = (n * n for n in nums)
>>> g
<generator object <genexpr> at 0x7fd90e7b22e0>
>>> next(g)
10000

# here's a generator version of the sum_sqr_evens(iterable) function
# note that () is optional here for the generator expression
>>> sum(n * n for n in nums if n % 2 == 0)
11028

# inner product
>>> sum(i * j for i, j in zip((1, 3, 5), (2, 4, 6)))
44
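As a rough illustration of the memory difference, the sketch below compares the size of a list comprehension result with that of a generator object using sys.getsizeof(). The value of n is arbitrary, and the exact sizes vary across Python versions and platforms.

```python
import sys

n = 100_000

sqr_list = [i * i for i in range(n)]   # all values stored at once
sqr_gen = (i * i for i in range(n))    # values produced on demand

# getsizeof() reports the size of the object itself,
# not of the integers a list refers to
list_size = sys.getsizeof(sqr_list)
gen_size = sys.getsizeof(sqr_gen)
print(list_size > gen_size)   # True

# both give the same result when passed to sum()
total_from_list = sum(sqr_list)
total_from_gen = sum(sqr_gen)
print(total_from_list == total_from_gen)   # True
```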

Here's an example with the join() method:

>>> items = (1, 'hi', [10, 20], 'bye')
>>> ':'.join(items)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: sequence item 0: expected str instance, int found
>>> ':'.join(str(i) for i in items)
'1:hi:[10, 20]:bye'

Exercises

  • Write a function that returns a dictionary sorted by values in ascending order.

    >>> marks = dict(Rahul=86, Ravi=92, Rohit=75, Rajan=79, Ram=92)
    >>> sort_by_value(marks)
    {'Rohit': 75, 'Rajan': 79, 'Rahul': 86, 'Ravi': 92, 'Ram': 92}
    
  • Write a function that returns a list of string slices as per the following rules:

    • return the input string as the only element if its length is less than 3 characters
    • otherwise, return all slices that have 2 or more characters
    >>> word_slices('i')
    ['i']
    >>> word_slices('to')
    ['to']
    >>> word_slices('table')
    ['ta', 'tab', 'tabl', 'table', 'ab', 'abl', 'able', 'bl', 'ble', 'le']
    
  • Square even numbers and cube odd numbers. For example, [321, 1, -4, 0, 5, 2] should give you [33076161, 1, 16, 0, 125, 4] as the output.

  • Calculate sum of squares of the numbers, only if the square value is less than 50. Output for (7.1, 1, -4, 8, 5.1, 12) should be 43.01.