
How to Use Comprehensions in Python

Introduction

In Python, a “comprehension” is a concise syntax for building a new collection (a list, set, dictionary, or generator) from an existing iterable.

Comprehensions can iterate over any iterable, including lists, tuples, sets, dictionaries, and strings, among others. They also work with iterables that we define ourselves.
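As a minimal sketch of that last point, the hypothetical `Countdown` class below (not part of any library) becomes iterable simply by implementing `__iter__`, and a comprehension consumes it exactly as it would a built-in list. The same snippet also iterates over a string and a dictionary, which yields its keys.

```python
class Countdown:
    """Hypothetical iterable that yields start, start - 1, ..., 1."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # Writing __iter__ as a generator function makes the object iterable
        current = self.start
        while current > 0:
            yield current
            current -= 1


# A comprehension works on our own iterable...
doubled = [n * 2 for n in Countdown(3)]
print(doubled)  # Output: [6, 4, 2]

# ...as well as on built-in iterables such as strings and dictionaries
letters = [c.upper() for c in "abc"]
print(letters)  # Output: ['A', 'B', 'C']

prices = {"apple": 1.0, "pear": 1.5}
names = [name for name in prices]  # iterating over a dict yields its keys
print(names)  # Output: ['apple', 'pear']
```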

Comprehensions are not only more concise than equivalent for loops but can also, in some cases, run faster.
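To make the comparison concrete, the sketch below builds the same list twice: once with a for loop that calls `append`, and once with a comprehension. Any speed difference depends on the Python version and the workload, so the snippet only shows how to time both versions with the standard `timeit` module and makes no claim about the numbers it prints.

```python
import timeit

# Version 1: an explicit for loop that appends to a list
cubes_loop = []
for n in range(5):
    cubes_loop.append(n**3)

# Version 2: the equivalent list comprehension, in a single expression
cubes_comp = [n**3 for n in range(5)]

print(cubes_loop)                # Output: [0, 1, 8, 27, 64]
print(cubes_loop == cubes_comp)  # Output: True

# Time both approaches on a larger input (results vary by machine)
loop_stmt = "result = []\nfor n in range(1000):\n    result.append(n**3)"
comp_stmt = "result = [n**3 for n in range(1000)]"
print("loop:", timeit.timeit(loop_stmt, number=200))
print("comprehension:", timeit.timeit(comp_stmt, number=200))
```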

The basic syntax is:

expression for element in iterable

Optionally, we can add a filter condition:

expression for element in iterable if condition
  • expression: This is the part that defines each element of the result. It is an operation applied to an element of the original iterable to produce the new value.
  • for element in iterable: This is the loop that iterates over each element of the original iterable.
  • if condition: This is an optional filter. Only the elements for which the condition is true are passed on to the expression.
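For instance, in the following sketch (using a made-up word list) the expression is `word.upper()`, the loop is `for word in words`, and the condition is `len(word) > 3`. The filter is checked first, so only the longer words are converted to uppercase.

```python
words = ["sun", "moon", "star", "sky", "comet"]

# expression:               word.upper()
# for element in iterable:  for word in words
# if condition:             if len(word) > 3
long_words = [word.upper() for word in words if len(word) > 3]

print(long_words)  # Output: ['MOON', 'STAR', 'COMET']
```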

The syntax varies only slightly depending on the type of object we want to build: mainly in whether we wrap it in [], {}, or ().
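A small sketch of that idea: the same expression and loop, wrapped in different delimiters, produce a different type of object. Note that parentheses produce a generator rather than a tuple; Python has no tuple comprehension, but we can pass a generator to `tuple()`.

```python
nums = [1, 2, 2, 3]

as_list = [n * 10 for n in nums]     # square brackets -> list
as_set = {n * 10 for n in nums}      # braces -> set (duplicates removed)
as_dict = {n: n * 10 for n in nums}  # braces with key: value -> dict
as_gen = (n * 10 for n in nums)      # parentheses -> generator

print(type(as_list).__name__)  # Output: list
print(type(as_set).__name__)   # Output: set
print(type(as_dict).__name__)  # Output: dict
print(type(as_gen).__name__)   # Output: generator

# There is no tuple comprehension; build one from a generator instead
as_tuple = tuple(n * 10 for n in nums)
print(as_tuple)  # Output: (10, 20, 20, 30)
```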

Let’s see how it works with some examples 👇

List Comprehensions

List comprehensions create a new list by applying an expression to each element of an iterable, optionally keeping only the elements that satisfy a condition. The basic syntax is:

[expression for element in iterable if condition]

For example, without a condition:

# Create a list of squares of numbers from 0 to 9
squares = [x**2 for x in range(10)]

print(squares)  # Output: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

And with a condition:

# Create a list of squares of the numbers 0 to 9 that are even
even_squares = [x**2 for x in range(10) if x % 2 == 0]

print(even_squares)  # Output: [0, 4, 16, 36, 64]
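A frequent source of confusion is where the `if` goes. An `if` placed after the loop filters elements out, and adding an `else` there is a syntax error. An `if ... else` placed in the expression part (a conditional expression, as used in the identity matrix below) keeps every element but chooses its value. The sketch below contrasts the two.

```python
# Filter: only even numbers are kept, the rest are dropped
evens = [x for x in range(6) if x % 2 == 0]
print(evens)  # Output: [0, 2, 4]

# Conditional expression: every number is kept, but odd ones become 0
evens_or_zero = [x if x % 2 == 0 else 0 for x in range(6)]
print(evens_or_zero)  # Output: [0, 0, 2, 0, 4, 0]
```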

Nested Comprehensions

List comprehensions can also be nested, which is useful for working with multidimensional structures.

# Create a 3x3 identity matrix
identity_matrix = [[1 if i == j else 0 for j in range(3)] for i in range(3)]

print(identity_matrix)
# Output: [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

In this example:

  • The outer part, [... for i in range(3)], builds a list of 3 rows, one for each value of i.
  • The inner part, [1 if i == j else 0 for j in range(3)], builds each row: a list of 3 elements that are 1 where i == j (the diagonal) and 0 elsewhere.
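To see what the nesting does, here is the same identity matrix built with explicit loops. A related pattern, shown at the end, places two for clauses in a single comprehension; they run in the same order as the equivalent nested loops (the first clause is the outer loop), which makes it a convenient way to flatten a matrix.

```python
size = 3

# Explicit-loop equivalent of the nested comprehension
identity_loop = []
for i in range(size):
    row = []
    for j in range(size):
        row.append(1 if i == j else 0)
    identity_loop.append(row)

identity_comp = [[1 if i == j else 0 for j in range(size)] for i in range(size)]
print(identity_loop == identity_comp)  # Output: True

# Two for clauses in one comprehension: the first one is the outer loop
flat = [value for row in identity_comp for value in row]
print(flat)  # Output: [1, 0, 0, 0, 1, 0, 0, 0, 1]
```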

Set Comprehensions

Set comprehensions are similar to list comprehensions, but they generate a set, so duplicate values are discarded and the order of the elements is not guaranteed. The syntax is:

{expression for element in iterable if condition}

For example,

# Create a set of odd numbers from 0 to 9
odds = {x for x in range(10) if x % 2 != 0}

print(odds)  # Output: {1, 3, 5, 7, 9}
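Because a set stores each value only once, a set comprehension is a convenient way to collect unique results. In this sketch with an arbitrary word list, several words share the same length, yet each length appears only once. The result is passed through `sorted()` because set order is not guaranteed.

```python
words = ["apple", "kiwi", "plum", "mango", "fig"]

# Collect the distinct word lengths; duplicates collapse automatically
lengths = {len(word) for word in words}

print(sorted(lengths))  # Output: [3, 4, 5]
```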

Dictionary Comprehensions

Dictionary comprehensions let us build dictionaries concisely. The syntax is:

{key: value for element in iterable if condition}

For example, without a condition:

# Create a dictionary with numbers and their squares for numbers from 0 to 9
squares_dict = {x: x**2 for x in range(10)}

print(squares_dict)
# Output: {0: 0, 1: 1, 2: 4, 3: 9, 4: 16, 5: 25, 6: 36, 7: 49, 8: 64, 9: 81}

And with a condition:

# Create a dictionary of squares only for even numbers from 0 to 9
even_squares_dict = {x: x**2 for x in range(10) if x % 2 == 0}

print(even_squares_dict)
# Output: {0: 0, 2: 4, 4: 16, 6: 36, 8: 64}
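Dictionary comprehensions are also handy for transforming existing data. The sketch below (with made-up data) iterates over `.items()` to unpack each key-value pair, filters a dictionary, inverts a mapping (which assumes the values are unique and hashable), and pairs up two parallel lists with `zip()`.

```python
stock = {"apples": 5, "pears": 0, "plums": 12}

# Keep only the items that are in stock
in_stock = {fruit: qty for fruit, qty in stock.items() if qty > 0}
print(in_stock)  # Output: {'apples': 5, 'plums': 12}

# Swap keys and values (assumes the values are unique and hashable)
codes = {"es": "Spanish", "en": "English"}
by_name = {name: code for code, name in codes.items()}
print(by_name)  # Output: {'Spanish': 'es', 'English': 'en'}

# Pair up two parallel lists
keys = ["a", "b", "c"]
values = [1, 2, 3]
pairs = {k: v for k, v in zip(keys, values)}
print(pairs)  # Output: {'a': 1, 'b': 2, 'c': 3}
```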

Generator Comprehensions

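Generator comprehensions (officially called generator expressions) use parentheses and return a generator object that produces its elements one at a time, only when they are requested. The syntax is: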

(expression for element in iterable if condition)

For example,

# Create a generator of squares of numbers from 0 to 9
squares_gen = (x**2 for x in range(10))

print(list(squares_gen))  # Output: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

It may look similar to the list case, but squares_gen is not a list: it is a generator object, which computes its values on demand.

This is useful when working with large volumes of data that we do not need to generate and hold in memory all at once.

In the list example, squares holds all of its values as soon as it is created. In contrast, squares_gen computes each value only when it is requested during iteration, and once it has been fully consumed it is exhausted: iterating over it a second time produces nothing.

In the example, both outputs look the same only because list() iterates through all the elements of the generator, building a list that print() then displays.
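A short sketch of this behavior: `next()` pulls values from a generator one at a time, and a fully consumed generator yields nothing the second time. The memory comparison uses `sys.getsizeof`, whose exact numbers depend on the Python version and platform (and which measures only the container itself), so only the relative sizes are meaningful here.

```python
import sys

squares_gen = (x**2 for x in range(10))

# Values are computed only when requested
print(next(squares_gen))  # Output: 0
print(next(squares_gen))  # Output: 1

# list() consumes the remaining values...
print(list(squares_gen))  # Output: [4, 9, 16, 25, 36, 49, 64, 81]
# ...so the generator is now exhausted
print(list(squares_gen))  # Output: []

# A list stores every value; a generator only stores its current state
big_list = [x**2 for x in range(1_000_000)]
big_gen = (x**2 for x in range(1_000_000))
print(sys.getsizeof(big_list) > sys.getsizeof(big_gen))  # Output: True

# A generator can be passed straight to a function that consumes iterables
print(sum(x**2 for x in range(10)))  # Output: 285
```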