Quite often, network engineers need to automate tasks so that their infrastructure is well understood (by systems or humans), consistent in configuration, or simply easier to deploy. Whatever the use case, we can all agree that automation is incredibly useful.

In this post, I am going to introduce one simple trick you can add to your Python scripts to massively reduce the time you spend waiting for your automation to run across your entire infrastructure. The trick is: multithreading. I don't intend to go into great detail here, but I will try my best to give you enough of an introduction to understand what's going on.

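Multithreading is a way of breaking your program up into multiple threads, each with its own set of instructions, that run concurrently: two or more sets of instructions making progress side by side rather than each waiting for the previous one to complete. One caveat for Python: in CPython, the global interpreter lock (GIL) lets only one thread execute Python code at a time, so threads do not speed up CPU-heavy work. They shine when threads spend most of their time waiting on I/O, such as an SSH session waiting for a router to reply, because a waiting thread lets the others run.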
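To see why threads help with SSH-heavy scripts despite the GIL, here is a small sketch that uses time.sleep as a stand-in for waiting on a router. The function fake_ssh_wait is hypothetical and no real devices are involved. Three 0.2-second waits take about 0.6 seconds one after another, but only about 0.2 seconds when each runs in its own thread, because a sleeping thread lets the others run.

```python
import threading
import time


def fake_ssh_wait(seconds):
    # Hypothetical stand-in for waiting on a router's reply.
    # time.sleep releases the GIL, so other threads can run meanwhile.
    time.sleep(seconds)


# Sequential: each wait starts only after the previous one finishes.
start = time.perf_counter()
for _ in range(3):
    fake_ssh_wait(0.2)
sequential = time.perf_counter() - start

# Threaded: all three waits overlap.
start = time.perf_counter()
threads = [threading.Thread(target=fake_ssh_wait, args=(0.2,)) for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
threaded = time.perf_counter() - start

print(f"sequential: {sequential:.2f}s, threaded: {threaded:.2f}s")
```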

I will use an example based on Netmiko. The example will SSH into a series of Cisco routers and gather their versions using the show version command. I have created a file called ips.txt which contains the list of router IPs, one per line. No examples in this post include error handling, since the focus is purely on multithreading. In real scripts, you should use try/except around connections and commands, and check for None return values throughout your code.
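Although the main examples skip error handling, here is a minimal sketch of the try/except plus None-check pattern just described. fake_get_version and describe are hypothetical helpers, and ConnectionError stands in for whatever your SSH library raises; with Netmiko you would typically catch its own exceptions, such as NetmikoTimeoutException and NetmikoAuthenticationException (check the documentation for your Netmiko version).

```python
def fake_get_version(ip):
    # Hypothetical stand-in for connecting and parsing "show version".
    if ip == "192.0.2.99":
        raise ConnectionError(f"could not reach {ip}")
    if ip == "192.0.2.50":
        return None  # Simulates output with no " Version" line.
    return "x.x.xx"


def describe(ip):
    # Catch connection failures and check for a missing version,
    # so that one bad router does not stop the whole run.
    try:
        version = fake_get_version(ip)
    except ConnectionError as exc:
        return f"{ip} failed: {exc}"
    if version is None:
        return f"{ip} returned no version"
    return f"{ip} is running version {version}"


for ip in ["192.0.2.1", "192.0.2.50", "192.0.2.99"]:
    print(describe(ip))
```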

The following example will:
- Read all the IP addresses from ips.txt.
- For each IP, connect via SSH.
- For each IP, fetch the version.
- For each IP, print the version.
- For each IP, disconnect from SSH.

from netmiko import Netmiko

def get_ip_list():
    # Return one stripped IP address per non-empty line of ips.txt.
    with open("ips.txt", "r") as f:
        return [line.strip() for line in f if line.strip()]

def get_version(ssh):
    ver = ssh.send_command("show version")

    for line in ver.splitlines():
        # Search for the line with the format:
        #  Version      : x.x.xx
        if not line.startswith(" Version"):
            # If we didn't find it, skip this line.
            continue
        # If we found it, split the line and grab the third element.
        return line.split()[2]

    return None # In case we don't get a version.

username = "<username>"
password = "<password>"

for _ip in get_ip_list():
    ip = _ip.strip()

    # Use Netmiko to connect to our IOS XR device via SSH.
    ssh = Netmiko(host=ip, username=username, password=password, device_type="cisco_xr")

    # Print out the router's software version.
    print(ip, "is running version", get_version(ssh))
        
    # Disconnect from SSH.
    ssh.disconnect()
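Because get_version mixes the SSH call with the text parsing, the parsing is hard to test without a router. One option, sketched below, is to move the parsing into a separate hypothetical helper, parse_version, and feed it sample text. The sample here is made up to match the format described in the code comment; it is not captured from a real device.

```python
def parse_version(output):
    # Look for the line formatted as " Version      : x.x.xx" and
    # return its third whitespace-separated field.
    for line in output.splitlines():
        if line.startswith(" Version"):
            return line.split()[2]
    return None  # No matching line found.


# Made-up sample text, not real device output.
sample = (
    "Cisco IOS XR Software, Version 1.2.3\n"
    "Build Information:\n"
    " Built By     : someone\n"
    " Version      : 1.2.3\n"
)

print(parse_version(sample))           # 1.2.3
print(parse_version("no match here"))  # None
```

With a helper like this, get_version could shrink to `return parse_version(ssh.send_command("show version"))`.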

Using the time command, we can measure how long the script takes to run. In my test, I am using 20 Cisco IOS XR routers, reached via OpenVPN over a standard residential connection.

$ time py3 app.py
xx.yy.zz.aa is running version x.x.xx
xx.yy.zz.bb is running version x.x.xx
xx.yy.zz.cc is running version x.x.xx
xx.yy.zz.dd is running version x.x.xx
xx.yy.zz.ee is running version x.x.xx
xx.yy.zz.ff is running version x.x.xx
xx.yy.zz.gg is running version x.x.xx
xx.yy.zz.hh is running version x.x.xx
xx.yy.zz.ii is running version x.x.xx
xx.yy.zz.jj is running version x.x.xx
xx.yy.zz.kk is running version x.x.xx
xx.yy.zz.ll is running version x.x.xx
xx.yy.zz.mm is running version x.x.xx
xx.yy.zz.nn is running version x.x.xx
xx.yy.zz.oo is running version x.x.xx
xx.yy.zz.pp is running version x.x.xx
xx.yy.zz.qq is running version x.x.xx
xx.yy.zz.rr is running version x.x.xx
xx.yy.zz.ss is running version x.x.xx
xx.yy.zz.tt is running version x.x.xx
python3.8 app.py  0.53s user 0.16s system 0% cpu 2:17.16 total

As you can see, this took about 2 minutes and 17 seconds to complete, or roughly 6.9 seconds per router. If you had 200 routers, that could easily run past 20 minutes! Not fun at all, especially if your equipment is older.
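The back-of-the-envelope estimate is simple division and multiplication. As a sketch, using a hypothetical helper estimate_total_seconds and assuming every router takes about as long as the 20 measured here and they are handled one after another:

```python
def estimate_total_seconds(measured_seconds, measured_devices, target_devices):
    # Assumes a roughly constant time per device and sequential processing.
    per_device = measured_seconds / measured_devices
    return per_device * target_devices


# 2:17.16 measured for 20 routers, scaled to 200 routers.
seconds = estimate_total_seconds(137.16, 20, 200)
print(f"{seconds:.0f} s, about {seconds / 60:.0f} minutes")  # 1372 s, about 23 minutes
```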

So... what's the solution? I'm sure you guessed it: multithreading, which should cut this down significantly. But before I modify the code above, I will give you a brief introduction to how multithreading works in Python, along with a few things to consider.

The first thing we must do is import the threading module. We can then use threading.Thread to create a thread, passing the function it should run as target, and call .start() to set it running.

import threading
import time

def say_two():
    # Wait a little bit, as if we're waiting for an API to reply.
    time.sleep(2)
    print("Two")

print("One")
thread = threading.Thread(target=say_two)
thread.start()
print("Three")

Breaking this code down: we print "One", create a new thread pointed at our say_two function and start it, then print "Three". say_two simply waits for 2 seconds, as if it were awaiting a response from an API or SSHing into a device, and then prints "Two".

One consideration I want to point out here is that threads run concurrently, which makes maintaining the order of operations a little more troublesome. Think for a second about what this program will print when executed, then compare your guess with the output below.

$ py3 thread-test.py 
One
Three
Two

You'll see that "Two" came last. This is because print("Three") ran while the other thread was still waiting for time.sleep(2) to complete. Note that your program itself runs in a thread, called the main thread, so we now have two threads running side by side, and one happens to finish before the other. By default, Python will not exit until all threads (other than daemon threads) have finished executing.
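To see that the program itself runs in a thread, a quick sketch can record which thread executes a function. threading.current_thread() returns the running thread, and Python names the main thread "MainThread"; the helper record_thread_name and the worker name "worker-1" are arbitrary choices for illustration.

```python
import threading


def record_thread_name(names):
    # Append the name of whichever thread is running this call.
    names.append(threading.current_thread().name)


names = []
record_thread_name(names)  # Runs in the main thread.

worker = threading.Thread(target=record_thread_name, args=(names,), name="worker-1")
worker.start()
worker.join()  # Wait for the worker so the list is complete before printing.

print(names)  # ['MainThread', 'worker-1']
```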

So how do we make sure the thread has finished before printing "Three"? That's simple! Call .join() on every thread you wish to wait for. .join() blocks the thread that calls it until the thread you called it on has finished. Take a look at the next example.

import threading
import time

def say_two():
    # Wait a little bit, as if we're waiting for an API to reply.
    time.sleep(2)
    print("Two")

print("One")
thread = threading.Thread(target=say_two)
thread.start()
thread.join()
print("Three")

Here we first call thread.start(), telling Python to begin executing our new thread's code. We then call thread.join(), which makes the main thread wait until our new thread has finished. The output below demonstrates this.

$ py3 thread-test.py
One
Two
Three
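One pitfall once you have several threads: calling .join() immediately after .start() inside a loop makes the main thread wait for each thread before starting the next, which quietly turns threaded code back into sequential code. The sketch below, using a hypothetical fake_task that sleeps for 0.2 seconds, compares that pattern with starting every thread first and joining them afterwards.

```python
import threading
import time


def fake_task():
    time.sleep(0.2)  # Pretend to wait on a router.


def join_each_immediately(count):
    # Anti-pattern: each join() blocks before the next thread is started.
    start = time.perf_counter()
    for _ in range(count):
        thread = threading.Thread(target=fake_task)
        thread.start()
        thread.join()
    return time.perf_counter() - start


def start_all_then_join(count):
    # Start every thread first, then wait for all of them.
    start = time.perf_counter()
    threads = [threading.Thread(target=fake_task) for _ in range(count)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return time.perf_counter() - start


slow = join_each_immediately(3)
fast = start_all_then_join(3)
print(f"join each immediately: {slow:.2f}s, start all then join: {fast:.2f}s")
```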

So, going back to our original SSH example, let's look at how to spread the work across several threads. Another important thing to bear in mind is that you don't want too many threads. That could easily be a topic of its own, so I will spare you the detail here, but the idea is that every thread holds its own SSH session and memory, so you should consider what your system is capable of and avoid throwing too much at it at once. No system, or human for that matter, enjoys being overworked.

from netmiko import Netmiko
import threading

def get_ip_list():
    with open("ips.txt", "r") as f:
        return f.readlines()

def get_version(ssh):
    ver = ssh.send_command("show version")
    for line in ver.splitlines():
        # Search for the line with the format:
        #  Version      : x.x.xx
        if not line.startswith(" Version"):
            continue  # If we didn't find it, skip this line.
        return line.split()[2]  # If we found it, split the line and grab the third element.
    return None # In case we don't get a version.

def print_version(ip, username, password):
    # Use Netmiko to connect to our IOS XR device via SSH.
    ssh = Netmiko(host=ip, username=username, password=password, device_type="cisco_xr")

    # Print out the router's software version.
    print(ip, "is running version", get_version(ssh))

    # Disconnect from SSH.
    ssh.disconnect()

username = "<username>"
password = "<password>"

for _ip in get_ip_list():
    ip = _ip.strip()
    
    # Spawn (create) a new thread, passing our function the IP address, username, and password for SSH.
    thread = threading.Thread(target=print_version, args=(ip, username, password))
    thread.start()

In this slightly longer example, you can see the addition of:
- The threading module.
- The print_version function, which accepts ip, username and password as parameters and does the connect, print and disconnect work for one router.
- A new thread, created and started for each IP inside the for loop at the bottom.

This simple upgrade will spawn a new thread for each IP address in our list. In my case, this is 20. If you have a larger infrastructure, you may wish to break this up into smaller groups to avoid overloading your system.
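Rather than splitting the list into groups by hand, one option is to swap in the standard library's concurrent.futures.ThreadPoolExecutor, which caps the number of worker threads and queues the remaining work. This is a different technique from the script above, sketched with a hypothetical fake_get_version that sleeps instead of using Netmiko; max_workers=4 is an arbitrary example value. pool.map also returns results in the same order as the input list, regardless of which router answers first.

```python
from concurrent.futures import ThreadPoolExecutor
import time


def fake_get_version(ip):
    # Hypothetical stand-in for connect, "show version", and disconnect.
    time.sleep(0.1)
    return ip, "x.x.xx"


ips = [f"192.0.2.{n}" for n in range(1, 11)]

# At most 4 routers are handled at once; the other IPs wait their turn.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(fake_get_version, ips))

# pool.map yields results in input order, so this matches the order of ips.
for ip, version in results:
    print(ip, "is running version", version)
```

Leaving the with block waits for all submitted work to finish, much like joining every thread.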

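Our threaded script never calls .join(), but because Python waits for all non-daemon threads before exiting, the program will still terminate as normal once every thread has finished. Because each thread prints as soon as its own router replies, the lines may not come out in the same order as ips.txt. Let's see what sort of improvement we got by adding just a few lines of code!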

$ time py3 app.py
xx.yy.zz.aa is running version x.x.xx
xx.yy.zz.bb is running version x.x.xx
xx.yy.zz.cc is running version x.x.xx
xx.yy.zz.dd is running version x.x.xx
xx.yy.zz.ee is running version x.x.xx
xx.yy.zz.ff is running version x.x.xx
xx.yy.zz.gg is running version x.x.xx
xx.yy.zz.hh is running version x.x.xx
xx.yy.zz.ii is running version x.x.xx
xx.yy.zz.jj is running version x.x.xx
xx.yy.zz.kk is running version x.x.xx
xx.yy.zz.ll is running version x.x.xx
xx.yy.zz.mm is running version x.x.xx
xx.yy.zz.nn is running version x.x.xx
xx.yy.zz.oo is running version x.x.xx
xx.yy.zz.pp is running version x.x.xx
xx.yy.zz.qq is running version x.x.xx
xx.yy.zz.rr is running version x.x.xx
xx.yy.zz.ss is running version x.x.xx
xx.yy.zz.tt is running version x.x.xx
python3.8 app.py  0.45s user 0.20s system 8% cpu 7.171 total

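What a great improvement! We've gone from 2 minutes 17 seconds down to just over 7 seconds, roughly 19 times faster, because the 20 SSH sessions now spend their waiting time overlapping instead of one after another. You'll also notice that the reported CPU usage has risen to 8% (it was 0% before). That figure is CPU time divided by elapsed time: the CPU time itself barely changed (0.53s user + 0.16s system before, 0.45s + 0.20s now), but it is now packed into 7 seconds instead of 137. Still, keep an eye on resource usage as you add threads, since each one holds its own SSH session and memory, and many threads working at the same moment will put more load on your machine.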
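To see the difference between CPU time and elapsed time for yourself, here is a sketch using time.process_time() (CPU time used by all threads of the process) and time.perf_counter() (wall-clock time). fake_wait is a hypothetical stand-in that sleeps instead of talking to a router, so the threads use almost no CPU even though they take real time to finish.

```python
import threading
import time


def fake_wait():
    time.sleep(0.3)  # Blocked waiting, using almost no CPU.


wall_start = time.perf_counter()
cpu_start = time.process_time()

threads = [threading.Thread(target=fake_wait) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

wall = time.perf_counter() - wall_start
cpu = time.process_time() - cpu_start
print(f"wall: {wall:.2f}s, cpu: {cpu:.3f}s, cpu%: {100 * cpu / wall:.1f}")
```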

I hope this has been a helpful guide, and I welcome your discussion down below if you have any further questions.

Happy automating!