

Use pre-commit to save time and avoid mistakes

I'm working in a team of data scientists, and most of us don't have a "proper" software background. Most here have some sort of natural sciences education and have picked up machine learning and software development along the way. This means that we don't have the same software craftsmanship foundation to build from when our ML models need to grow, scale, and change.

There are a lot of ways to improve in this area, but a simple one to implement for a whole team in one go is to require pre-commit to be installed in all projects. This is a tool that lets you define a set of checks that are performed on your code every time you make a commit in git (you are using git, right?).

Installation

Make (or copy from below) a file called .pre-commit-config.yaml and place it in the root of your repository. Then

pip install pre-commit
pre-commit install

Run

Every time you git commit, the hooks you have defined in .pre-commit-config.yaml will be run on the changed files.

If for some reason you want to run the hooks on all files (for instance in your CI/CD pipeline), you can do

pre-commit run --all-files

Individual checks

Stop dealing with whitespace diffs in your PRs

-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
    -   id: end-of-file-fixer
    -   id: trailing-whitespace
-   repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
    - id: isort
      name: isort

The first two hooks fix small whitespace mistakes. Each file should end with exactly one newline, and there should be no whitespace at the end of a line.

isort sorts your import statements. It is a minor thing, but it will group imports into three groups:

  1. Python standard library imports.
  2. Third-party libraries.
  3. Local code.
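
As a rough illustration of that grouping (the module names are just examples, and the group comments are mine, not something isort inserts), a sorted import block ends up looking like this:

# Standard library
import os
import sys

# Third-party libraries
import numpy as np
import pandas as pd

# Local code
from my_project import utils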

There is some setup needed to make it compatible with black. See Full setup for details.

You probably committed this by mistake

-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
    -   id: check-ast
    -   id: check-json
    -   id: check-yaml
    -   id: debug-statements
    -   id: detect-aws-credentials
        args: [--allow-missing-credentials]
    -   id: detect-private-key
    -   id: check-merge-conflict
    -   id: check-added-large-files
        args: ['--maxkb=3000']

Here is a bunch of hooks that will:

  • Check if your Python code is valid (avoiding those SyntaxErrors that sometimes crop up)

  • Check that JSON and YAML files can be parsed

  • Check that you don't have any leftover breakpoint() statements from a debugging session.

  • Check that you haven't accidentally committed secrets.

  • Check that you haven't committed an unresolved merge conflict, like leaving

    <<<<<<< HEAD
    

    in the file.

  • Check that you haven't committed an unusually large file. If you actually need large files inside your repo, use git-lfs.

Make Jupyter Notebook diffs easier to deal with

-   repo: https://github.com/kynan/nbstripout
    rev: 0.5.0
    hooks:
    - id: nbstripout

nbstripout is very useful if you commit a lot of Jupyter Notebooks to your repo. The output cells are saved in the file, so if you are outputting some large plots, each notebook can become quite big. If your notebooks are not just one-off explorations, but you come back to them more than once, this will make the PR diffs much easier to read.

If that is NOT the case, maybe you don't want or need this one.

Stop arguing over code style

-   repo: https://github.com/psf/black
    rev: 21.7b0
    hooks:
    -   id: black
-   repo: https://gitlab.com/pycqa/flake8
    rev: 3.7.9
    hooks:
    - id: flake8
      additional_dependencies:
          - flake8-unused-arguments

black is a code autoformatter. It has opinions on what is good style and what is bad, and I mostly agree with those opinions. The very cool thing about black is that it does not just find places where you are not following the style; it automatically rewrites your code to follow it.
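
As a sketch of the kind of rewrite it does (my own illustration, not captured black output), black will take something like

def scale(values,factor = 2):
    return [ v*factor for v in values ]

and turn it into

def scale(values, factor=2):
    return [v * factor for v in values]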

flake8 is a linter. It can check for more kinds of style errors, but it will not fix anything; it can only complain. This is mostly fine, because it is often trivial to fix the issues that flake8 raises.
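
For example, given a hypothetical snippet like the one below, flake8 would flag the unused import, and the flake8-unused-arguments plugin from the config above would flag the argument that is never used (the exact codes and messages may vary with your setup):

import json  # F401: 'json' imported but unused


def greet(name, greeting):  # U100: unused argument 'greeting'
    return f"Hello, {name}"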

Both of these tools need some config to work as desired. See Full setup for details.

Optional static type checking

-   repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.782
    hooks:
    -   id: mypy
        args: [--ignore-missing-imports]

You can optionally use static typing in Python now. mypy is a tool that runs static analysis on your Python files, and it will complain if your argument or return types don't match your type hints.
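
A minimal, made-up example of the kind of mismatch it will catch:

def halve(n: int) -> int:
    # mypy complains here: dividing an int with / produces a float,
    # which does not match the declared return type of int
    return n / 2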

Full setup

If you just want to copy my setup, add these three files to the root of your repo:

.pre-commit-config.yaml
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
    -   id: check-ast
    -   id: check-json
    -   id: check-yaml
    -   id: debug-statements
    -   id: detect-aws-credentials
        args: [--allow-missing-credentials]
    -   id: detect-private-key
    -   id: check-merge-conflict
    -   id: check-added-large-files
        args: ['--maxkb=3000']
-   repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.782
    hooks:
    -   id: mypy
        args: [--ignore-missing-imports]
-   repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
    - id: isort
      name: isort
-   repo: https://github.com/psf/black
    rev: 21.7b0
    hooks:
    -   id: black
-   repo: https://gitlab.com/pycqa/flake8
    rev: 3.7.9
    hooks:
    - id: flake8
      additional_dependencies:
          - flake8-unused-arguments
-   repo: https://github.com/kynan/nbstripout
    rev: 0.5.0
    hooks:
    - id: nbstripout

pyproject.toml
[tool.black]
line-length = 100
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

[tool.isort]
profile = "black"
line_length = 100

.flake8
[flake8]
ignore = E203, E266, E501, W503
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9,U100
unused-arguments-ignore-abstract-functions = True

Updates

  • 2021-09-08: Add flake8-unused-arguments.

Converting between custom dataclasses and numpy arrays

TL;DR: Implement __array__(), __len__() and __getitem__() methods on your dataclass. See the final section for a working example.

I have gotten increasingly interested in Python type hints, and in a recent project I'm creating a lot of custom types to define interfaces between different modules in my application. I usually try to stick to standard library Python types, but the dataclass can be pretty neat.

Here is an example of a simple custom dataclass

from dataclasses import dataclass
@dataclass
class Point2D:
    x: float
    y: float

If I want a simple way to convert this to a numpy array, I run into a few stumbling blocks:

Converting one instance to a np.array (the naive way)

import numpy as np

p = Point2D(x=0.2, y=3.0)
arr = np.array(p)
print(arr, type(arr), arr.dtype)
# Point2D(x=0.2, y=3.0) <class 'numpy.ndarray'> object

I don't get the values from Point2D; I just get an array with the object inside. However, we can implement an __array__ method on Point2D that will allow numpy to produce an array with the correct dtype.

@dataclass
class Point2D:
    ...
    def __array__(self):
        return np.array([self.x, self.y])

Now we get a much more sensible result when converting

p = Point2D(x=0.2, y=3.0)
arr = np.array(p)
print(arr, type(arr), arr.dtype)
# [0.2 3. ] <class 'numpy.ndarray'> float64

The trouble comes when we want to make a new custom type that inherits from Point2D.

Inheriting the __array__ method

Let's make a simple extension of Point2D to 3 dimensions

@dataclass
class Point3D(Point2D):
    z: float

If we try to convert this into a numpy array, we run into trouble

p = Point3D(x=0.2, y=3.0, z=-1.0)
arr = np.array(p)
print(arr, type(arr), arr.dtype)
# [0.2 3. ] <class 'numpy.ndarray'> float64

We are missing the new z dimension!

One fix is to make a new __array__ method.

@dataclass
class Point3D(Point2D):
    ...
    def __array__(self):
        return np.array([self.x, self.y, self.z])

That will definitely work, but it breaks the DRY principle. Instead, we can make use of dataclasses.astuple

from dataclasses import astuple

@dataclass
class Point2D:
    x: float
    y: float
    def __array__(self):
        return np.array(astuple(self))

@dataclass
class Point3D(Point2D):
    z: float

p = Point3D(x=0.2, y=3.0, z=-1.0)
arr = np.array(p)
print(arr, type(arr), arr.dtype)
# [ 0.2  3.  -1. ] <class 'numpy.ndarray'> float64

Less repetition and less chance of mistakes. Nice.

Our next issue is when dealing with more than one instance of these custom classes at a time.

Converting lists of custom dataclasses with nested conversion

If I have a few Points, I might want a 2D np.array with all the values. The naive approach would be to do

p1 = Point3D(1, 2, 3)
p2 = Point3D(4, 5, 6)
list_of_points = [p1, p2] 
arr = np.array(list_of_points)
print(arr, type(arr), arr.dtype, arr.shape)
# [Point3D(x=1, y=2, z=3) Point3D(x=4, y=5, z=6)] <class 'numpy.ndarray'> object (2,)

Not only do I not get what I expected, I even get a bunch of warnings from numpy that this is a no-go

<input>:3: FutureWarning: The input object of type 'Point3D' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Point3D', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.
<input>:3: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.

We already know we can get a numpy array from a single instance, so we can get around this hurdle with a simple list comprehension

arr = np.array([np.array(p) for p in list_of_points])
print(arr, type(arr), arr.dtype, arr.shape)
# [[1 2 3]
# [4 5 6]] <class 'numpy.ndarray'> int32 (2, 3)

That works, but it feels more like a workaround than a real solution. Should I really have to remember to do this nested conversion every time I want to get my data in a 2D matrix?

No, if I just implement two additional methods on the base class, I don't have to think about this any more.

Converting lists of custom dataclasses with __len__ and __getitem__

from dataclasses import dataclass, astuple
import numpy as np

@dataclass
class Point2D:
    x: float
    y: float

    def __array__(self):
        return np.array(astuple(self))

    def __len__(self):
        return astuple(self).__len__()

    def __getitem__(self, item):
        return astuple(self).__getitem__(item)

@dataclass
class Point3D(Point2D):
    z: float

p1 = Point3D(1, 2, 3)
p2 = Point3D(4, 5, 6)
list_of_points = [p1, p2] 
arr = np.array(list_of_points)
print(arr, type(arr), arr.dtype, arr.shape)
# [[1 2 3]
# [4 5 6]] <class 'numpy.ndarray'> int32 (2, 3)

We are again abusing dataclasses.astuple to let us access each field programmatically, in order.
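
For reference, and continuing from the snippet above, astuple returns the field values in definition order, with the inherited fields first, which is exactly what both helper methods rely on:

print(astuple(Point3D(1, 2, 3)))
# (1, 2, 3)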

To be honest, I don't really understand why __array__ does not work for lists of custom dataclasses, but __len__ and __getitem__ do. If numpy is looping through each element one at a time to add it to an array, we might run into some performance issues at some point.

But for now, this looks fairly clean to my taste, and it is very practical.