Overview#

In this example, we’ll begin by building a neural network, then export it from memory to disk in a format Triton can use to serve it for inference. In practice, of course, we’d train this network on some data between these steps, but since we’re more interested here in the inference side, we’ll pretend the training has been done elsewhere. For simplicity, we’ll build a 1D convolutional network with a single output (which might be used, e.g., for binary classification).

Once our model has been properly exported, we’ll spin up a Triton inference service which will load the model and expose it for inference via gRPC requests. We’ll then build some dummy inference data and iterate through it to build requests to send to our inference service, aggregating its responses into a timeseries of network outputs.

As we’ll see, achieving higher data throughput (which we’ll measure in units of seconds of data per second, or s’ / s) at higher inference sampling rates (the rate at which we sample windows from our timeseries) requires adjusting this vanilla implementation to shift the bottleneck to the GPU and extract as much compute utilization from it as possible. hermes makes those adjustments simple, and helps you achieve greater scale once you’ve made them.

We’ll start with our imports. From hermes, we’ll be using

  • hermes.quiver to handle exporting our model to a format usable by Triton

  • hermes.aeriel.serve to spin up an inference service locally using Python APIs

  • hermes.aeriel.client to make requests to that inference service

  • hermes.stillwater.ServerMonitor to keep track of server-side inference metrics

import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from ratelimiter import RateLimiter

# our hermes imports
from hermes import quiver as qv
from hermes.aeriel.client import InferenceClient
from hermes.aeriel.serve import serve
from hermes.stillwater import ServerMonitor

# couple cheap local utilities
from src import utils, plotting

logger = utils.get_logger()

Now let’s establish some parameters for both the model we’d like to build as well as for our inference time deployment.

# model parameters
NUM_IFOS = 2  # number of interferometers analyzed by our model
SAMPLE_RATE = 2048  # rate at which input data to the model is sampled
KERNEL_LENGTH = 4  # length of the input to the model in seconds

# inference parameters
INFERENCE_DATA_LENGTH = 2048  # amount of data to analyze at inference time
INFERENCE_SAMPLING_RATE = 0.25  # rate at which we'll sample input windows from the inference data
INFERENCE_RATE = 250  # seconds of data we'll try to analyze per second

# convert some of these into more useful units for slicing purposes
kernel_size = int(SAMPLE_RATE * KERNEL_LENGTH)
inference_stride = int(SAMPLE_RATE / INFERENCE_SAMPLING_RATE)
inference_data_size = int(SAMPLE_RATE * INFERENCE_DATA_LENGTH)
num_inferences = (inference_data_size - kernel_size) // inference_stride + 1

# limit the number of requests we make per second
# so that we don't overload the network or server
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)
rate_limiter = RateLimiter(max_calls=kernels_per_second, period=1)

Now let’s build our extremely simple network:

class GlobalAvgPool(torch.nn.Module):
    def forward(self, x):
        return x.mean(axis=-1)


nn = torch.nn.Sequential(
    torch.nn.Conv1d(NUM_IFOS, 8, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(8, 32, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(32, 64, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(64, 128, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(128, 256, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    GlobalAvgPool(),
    torch.nn.Linear(256, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1)
)

# INSERT TRAINING CODE HERE

Ok, now the set-up work is done. We have a “trained” neural network, and we’re ready to export it for inference. Enter hermes. One of the key concepts in as-a-service inference is the idea of a model repository, a local (or cloud-based) directory with a prescribed structure that hosts all the versions of all the models to be served for inference.

Maintaining this prescribed structure, as well as the configs Triton needs to map named inputs and outputs, can be onerous, so hermes.quiver was built to take the headache out of building and maintaining these repositories. In this next step, we’ll build a model repository (clearing it beforehand in case you run this notebook multiple times), add a new entry to it for the network we just built, then export the current version of this network to that repository as an ONNX binary which Triton can load and serve for inference.

# let's make sure we're starting with a fresh repo
repo_path = "model-repo"
utils.clear_repo(repo_path)

# initialize a blank model repository
repo = qv.ModelRepository(repo_path)
assert len(repo.models) == 0  # this attribute will get updated as we add models

# create a new entry in the repo for our model
model = repo.add("my-classifier", platform=qv.Platform.ONNX)
assert len(repo.models) == 1
assert model == repo.models["my-classifier"]

# now export our current version of the network to this entry.
# Since we haven't exported any versions of this model yet,
# Triton needs to know what names to give the inputs and
# outputs and what shapes to expect, so we have to specify
# them explicitly this first time.
# Note that -1 indicates variable length batch dimension.
model.export_version(
    nn,
    input_shapes={"hoft": (-1, NUM_IFOS, kernel_size)},
    output_names=["prob"]
)
'my-classifier/1/model.onnx'

Note that our model is associated with a Config object that describes the metadata Triton requires

model.config
name: "my-classifier"
platform: "onnxruntime_onnx"
input {
  name: "hoft"
  data_type: TYPE_FP32
  dims: -1
  dims: 2
  dims: 8192
}
output {
  name: "prob"
  data_type: TYPE_FP32
  dims: -1
  dims: 1
}

and that each model can be associated with multiple different versions corresponding to different weight values, or even wholesale different architectures. The only thing that matters is that the network represents the same input-to-output mapping:

model.versions
[1]

And with that, our model is ready to be served for inference! In the code below, we’ll use hermes.aeriel.serve to spin up a Singularity container on a single GPU (index 0) which will run a Triton inference service in the background. Once the serve context exits, the service and the container running it will both be spun down.

Once the server is up and running, we’ll use hermes.aeriel.client to establish a client connection to it, then iterate through our dummy data and make requests using this connection. The requests are made asynchronously and the responses parsed in a callback thread. The parsed responses are made available in the main thread through a queue which can be succinctly accessed by InferenceClient.get(), which will return None if there are no responses to be returned.

Note that below, we’ll do our inference in batch sizes of 1. In principle, we could get better throughput by increasing that batch size (at the cost of some latency), but as we’ll see momentarily there are more pressing issues we need to address first.
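To make the batching idea concrete, here’s a sketch (pure numpy, with the tutorial’s parameter values hard-coded so it stands on its own) of how several windows could be stacked into a single batched request, which the model’s variable-length batch dimension would accept:

```python
import numpy as np

# parameters from above, hard-coded here for self-containment
SAMPLE_RATE = 2048
NUM_IFOS = 2
kernel_size = 4 * SAMPLE_RATE               # KERNEL_LENGTH = 4
inference_stride = int(SAMPLE_RATE / 0.25)  # INFERENCE_SAMPLING_RATE = 0.25

# dummy timeseries of the same kind we'll generate below
hoft = np.random.randn(NUM_IFOS, 10 * SAMPLE_RATE).astype("float32")

# stack 2 consecutive windows along a new leading batch dimension
batch = np.stack([
    hoft[:, i * inference_stride: i * inference_stride + kernel_size]
    for i in range(2)
])
print(batch.shape)  # (2, 2, 8192): (batch, channels, time)
```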

# Start by spinning up an inference service.
# Note that this will print the singularity command
# used to start the service: an unavoidable quirk of
# singularity right now (the alternative is an
# unnecessary warning), but it's probably good to see
# what's really happening under the hood anyway.
# The `instance` returned by the context is an object
# representing the running singularity container instance.
with serve("model-repo", gpus=[0]) as instance:
    # Do our data generation in parallel while the server
    # spins up. This will obviously be pretty fast, but
    # for more complicated data generation steps this can
    # be a good way to parallelize your efforts.
    hoft = np.random.randn(NUM_IFOS, inference_data_size).astype("float32")

    # now wait until the inference service is online and
    # ready to receive requests before we attempt to connect
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    # establish a client connection to the inference service
    # and infer the names and shapes of the inputs it expects
    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1  # can use -1 to imply latest version
    )

    # client context establishes a streaming connection to
    # the inference service. Since we're not yet streaming
    # updates but making requests individually, we don't
    # technically need this, but it's a good habit to get into.
    with client:
        # now iterate through our inference timeseries at the
        # prescribed stride and send these inputs to the server
        # for inference.
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size

            # add a dummy dimension for the batch
            kernel = hoft[:, start: stop][None]
            with rate_limiter:
                client.infer(kernel)

        # now that we've submitted all our inference requests,
        # start pulling them from the output queue as they
        # become available.
        results = []
        while len(results) < num_inferences:
            response = client.get()
            if response is not None:
                y, request_id, sequence_id = response
                results.append(y[:, 0])
        logger.info("Inference complete!")

# concatenate all the responses into a single timeseries
results = np.concatenate(results)
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest pusheena_despacito_2246
16:58:54.547    Waiting for inference service to come online...                                                    
16:59:41.155    Service ready!                                                                                     
16:59:49.216    Inference complete!                                                                                

In principle, that does it: we served a model, ran data through it, and judging from the timestamps in the logs, we roughly hit our inference rate target. You now have everything you need to run inference-as-a-service with hermes. But there are some considerations that can help improve both network performance and throughput. For some inspiration, let’s take a look at what our timeseries of network responses looks like:

p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)

So as we should have expected, it’s just a timeseries of more or less random data. What is worth noting about this timeseries, however, is the rate at which it is sampled: 0.25 Hz means that for shorter-duration events like binary black hole mergers, we only get to make one prediction on each event. Surely there might be some benefit to taking predictions from multiple overlapping windows containing the same event and aggregating them somehow. Let’s increase our inference sampling rate to, say, 4 Hz and see how this impacts our throughput.
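One simple aggregation scheme (an illustration of the idea, not something hermes does for you) is to average the predictions from every window containing a given instant. At 4 Hz sampling with 4 s kernels, each instant is covered by up to 16 windows, so this amounts to a moving average:

```python
import numpy as np

# one prediction per window, as our inference loop produces
preds = np.random.randn(100)

# KERNEL_LENGTH * INFERENCE_SAMPLING_RATE = 4 * 4 = 16 windows
# overlap any given instant, so average over a 16-sample boxcar
window_coverage = 16
boxcar = np.ones(window_coverage) / window_coverage
smoothed = np.convolve(preds, boxcar, mode="valid")
print(smoothed.shape)  # (85,): 100 - 16 + 1 fully-covered samples
```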

For this next round of inference, I’m going to add a little more complication up front in exchange for slightly more elegance at inference time. Rather than iterating through responses after all our inference requests have been submitted, I’ll set up a callback up front that aggregates our responses into an array in real time in the callback thread, then returns the array once it’s complete.

class Callback:
    def __init__(self, num_inferences):
        self.y = np.zeros((num_inferences,))

    def __call__(self, response, request_id, sequence_id):
        self.y[request_id] = response[0, 0]
        if (request_id + 1) == len(self.y):
            return self.y


# reset some of our parameters with a new inference sampling rate
INFERENCE_SAMPLING_RATE = 4
inference_stride = int(SAMPLE_RATE / INFERENCE_SAMPLING_RATE)
num_inferences = (inference_data_size - kernel_size) // inference_stride + 1
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)

rate_limiter = RateLimiter(max_calls=kernels_per_second / 20, period=0.05)
callback = Callback(num_inferences)

# from here things will look more or less the same
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    # instantiate client with custom callback
    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1,
        callback=callback
    )

    with client:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size
            kernel = hoft[:, start: stop][None]

            # pass explicit request ids this time
            # for the callback to use
            with rate_limiter:
                client.infer(kernel, request_id=i)

        # now wait until the callback returns its
        # filled out array to the client's queue
        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest chunky_hippo_6797
16:59:51.858    Waiting for inference service to come online...                                                    
17:00:35.590    Service ready!                                                                                     
17:00:51.840    Inference complete!                                                                                

And now let’s take a look at this timeseries

p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)

So things work with higher-frequency inference, but a quick eyeball of our logs indicates that we’re now falling well short of our intended inference rate target (at time of writing, I’m seeing a time delta of ~43 s, which translates to an inference rate of ~48 seconds of data / s). Why is that? Is our network capable of handling the desired number of requests per second, or is there something else going on?

For diagnosing these sorts of questions, Triton includes a secondary service which makes some inference metrics available (served at port 8002 by default). Clients can query this service for per-model information like cumulative time spent queueing and executing requests, as well as cumulative inference counts. The hermes.stillwater library has a ServerMonitor class which, in a separate process, queries this service and organizes the returned metrics from potentially many servers and writes them to a local log file.

Let’s run the same snippet from above with a monitor in place and take a look at how queuing latency evolves over time. If requests spend longer and longer queuing over time, then our network is acting as a bottleneck and we need to scale it up. Otherwise, something else is going wrong.
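The ServerMonitor handles polling and parsing for you, but it’s worth knowing what the raw data looks like: Triton’s metrics service serves Prometheus-format text, one metric per line. Here’s a sketch of parsing such a line with the standard library (the sample text is illustrative, modeled on Triton’s nv_inference_* counters):

```python
import re

# an illustrative Prometheus-format line of the kind Triton
# serves at its metrics port (8002 by default)
sample = (
    'nv_inference_queue_duration_us'
    '{model="my-classifier",version="1"} 587301553'
)

# split into metric name, label set, and value
pattern = re.compile(r'^(\w+)\{(.*?)\}\s+([\d.e+]+)$')
name, labels, value = pattern.match(sample).groups()
labels = dict(kv.split("=") for kv in labels.split(","))

print(name, labels["model"].strip('"'), float(value))
```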

# set up a new directory just for our metrics
metrics_dir = Path("metrics")
metrics_dir.mkdir(exist_ok=True)
metrics_file = metrics_dir / "non-streaming_single-model.csv"

callback = Callback(num_inferences)
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="my-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor",
        rate=4
    )

    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size
            kernel = hoft[:, start: stop][None]

            # for various reasons, the rate limiter won't work
            # here, so we'll be generous and insert a sleep
            # for half as long as the actual rate would be
            client.infer(kernel, request_id=i)
            time.sleep(0.5 / INFERENCE_SAMPLING_RATE / INFERENCE_RATE)

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest grated_leg_5296
17:00:56.199    Waiting for inference service to come online...                                                    
17:01:37.029    Service ready!                                                                                     
17:01:47.234    Inference complete!                                                                                

Now let’s load in the CSV that our monitor produced and take a look at the data it captured.

df = pd.read_csv(metrics_file)
df
timestamp ip model count queue compute_input compute_infer compute_output request
0 1.663891e+09 localhost my-classifier 87 587301553 9529 6885885 784 594202569
1 1.663891e+09 localhost my-classifier 232 1557675965 12471 73676 1842 1557775318
2 1.663891e+09 localhost my-classifier 331 2145661313 22275 7158593 2987 2152862687
3 1.663891e+09 localhost my-classifier 465 2857294202 24292 146878 3960 2857492663
4 1.663891e+09 localhost my-classifier 445 2551900100 23089 148376 3760 2552098128
5 1.663891e+09 localhost my-classifier 459 2444907900 24595 146049 3668 2445105272
6 1.663891e+09 localhost my-classifier 457 2241774583 24063 146711 4024 2241971869
7 1.663891e+09 localhost my-classifier 392 1770680539 22459 148572 3507 1770877783
8 1.663891e+09 localhost my-classifier 417 1734183397 22599 149888 3640 1734381239
9 1.663891e+09 localhost my-classifier 441 1625196357 23455 148051 3652 1625394260
10 1.663891e+09 localhost my-classifier 218 728727283 20618 153130 3996 728925976
11 1.663891e+09 localhost my-classifier 441 1351631666 24599 145017 3758 1351827770
12 1.663891e+09 localhost my-classifier 472 1249338573 23598 148549 3708 1249536293
13 1.663891e+09 localhost my-classifier 293 674164299 39172 133387 3226 674361466
14 1.663891e+09 localhost my-classifier 149 322752920 42381 150264 3191 322972130
15 1.663891e+09 localhost my-classifier 386 766527339 44580 133665 3243 766730229
16 1.663891e+09 localhost my-classifier 415 674953706 30304 137799 3660 675148432
17 1.663891e+09 localhost my-classifier 457 545564093 23139 147243 3520 545758828
18 1.663891e+09 localhost my-classifier 469 290677762 22557 148155 3163 290873176
19 1.663891e+09 localhost my-classifier 302 37216007 14891 98494 2104 37345551
20 1.663891e+09 localhost my-classifier 66 3553 4849 31794 816 45876
21 1.663891e+09 localhost my-classifier 67 3477 3777 25752 725 37867
22 1.663891e+09 localhost my-classifier 64 7363 4351 24825 657 41130
23 1.663891e+09 localhost my-classifier 74 3021 4662 36788 714 49819
24 1.663891e+09 localhost my-classifier 76 3660 4521 29046 674 42843
25 1.663891e+09 localhost my-classifier 70 4036 4467 31954 700 45591
26 1.663891e+09 localhost my-classifier 47 8670 4569 32898 688 51345
27 1.663891e+09 localhost my-classifier 82 94784 4455 29302 748 133859
28 1.663891e+09 localhost my-classifier 82 4372 4010 27464 611 40991
29 1.663891e+09 localhost my-classifier 55 17061 5924 48886 935 79037
30 1.663891e+09 localhost my-classifier 33 27969 4931 37830 1056 80273
31 1.663891e+09 localhost my-classifier 75 6033 5107 33550 812 51299

The queue, compute_input, compute_infer, and compute_output columns all represent different steps in the inference compute of a single request. The values in each column represent the cumulative microseconds spent on each step over all of the requests computed between each ping to the metrics service, the number of which is indicated by the count column. Knowing this, let’s reframe some of the info in these columns in a more useful fashion for our purposes.
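As a quick sanity check of that interpretation, the first row above reports 87 requests and a cumulative 587301553 microseconds of queue time, so:

```python
# cumulative microseconds / request count = average per request
avg_queue_us = 587301553 / 87
print(round(avg_queue_us))  # 6750593, i.e. ~6.75 s per request
```

i.e. each of those early requests spent nearly seven seconds queuing.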

# we'll be doing this a lot, so let's record it as a function
def cleanup_df(df, t0):
    df["Time since start (s)"] = df["timestamp"] - t0
    df["Average queue time (us)"] = df["queue"] / df["count"]

    # count all inference steps as a single metric
    # of inference latency
    infer_time = df[[f"compute_{i}" for i in ["input", "infer", "output"]]].sum(axis=1)
    df["Average infer time (us)"] = infer_time / df["count"]

    # use the number of inferences completed in an interval
    # along with the inference sampling rate to put throughput
    # in units of data seconds per second
    df["Throughput (s' / s)"] = df["count"] / df["timestamp"].diff() / INFERENCE_SAMPLING_RATE

    return df[[
        "Time since start (s)",
        "Throughput (s' / s)",
        "Average queue time (us)",
        "Average infer time (us)"
    ]]

df = cleanup_df(df, df.timestamp.min())
df
Time since start (s) Throughput (s' / s) Average queue time (us) Average infer time (us)
0 0.000000 NaN 6.750593e+06 79266.643678
1 0.101490 571.482073 6.714121e+06 379.262931
2 0.203203 813.569839 6.482360e+06 21703.489426
3 0.304591 1146.577559 6.144719e+06 376.623656
4 0.406262 1094.220999 5.734607e+06 393.764045
5 0.507635 1131.956029 5.326597e+06 379.764706
6 0.609230 1124.566687 4.905415e+06 382.490153
7 0.711257 960.526137 4.517042e+06 445.250000
8 0.812940 1025.250927 4.158713e+06 422.366906
9 0.914786 1082.516957 3.685253e+06 397.183673
10 1.017116 532.588316 3.342786e+06 815.339450
11 1.118904 1083.138178 3.064924e+06 393.138322
12 1.220450 1162.026193 2.646904e+06 372.574153
13 1.323272 712.400683 2.300902e+06 599.948805
14 1.441036 316.308677 2.166127e+06 1314.335570
15 1.542648 949.692359 1.985822e+06 470.176166
16 1.644157 1022.078312 1.626394e+06 413.886747
17 1.745598 1126.274128 1.193795e+06 380.529540
18 1.847353 1152.278619 6.197820e+05 370.735608
19 1.959072 675.803330 1.232318e+05 382.413907
20 2.060756 162.267268 5.383333e+01 567.560606
21 2.162252 165.030002 5.189552e+01 451.552239
22 2.263905 157.399162 1.150469e+02 466.140625
23 2.366733 179.912041 4.082432e+01 569.783784
24 2.468355 186.966758 4.815789e+01 450.539474
25 2.572194 168.530257 5.765714e+01 530.300000
26 2.674086 115.317916 1.844681e+02 811.808511
27 2.775749 201.646866 1.155902e+03 420.792683
28 2.877061 202.346339 5.331707e+01 391.280488
29 2.982984 129.811063 3.102000e+02 1013.545455
30 3.084848 80.990261 8.475455e+02 1327.787879
31 3.186673 184.139457 8.044000e+01 526.253333

Now let’s plot the queue time and throughput as a function of time, and consider what this data tells us:

p = plotting.plot_inference_metrics_vs_time(my_classifer=df)
plotting.show(p)

Counterintuitively, our queue time starts off really high, then goes down over time! Moreover, the length of the x-axis, around 3-4 seconds, doesn’t match up with the 10 second time delta we see between our logs above. What’s going on here?

It turns out that the first couple of inferences can often take substantially longer than the rest, as Triton tries to optimize the compute kernels used to actually perform inference. So while the first request is taking several seconds to process, an enormous queue builds up while we inundate the server with follow-ups. Meanwhile, the ServerMonitor doesn’t start saving metrics until the server-side metrics service indicates that at least one inference has completed, which explains why we see a much shorter observation window.

Once that first request is processed, our throughput skyrockets while the network tears through data as fast as it can (in fact this gives a decent estimate of the network’s individual throughput cap: around 1200 s’ / s). However, we aren’t supplying new data to the network fast enough to keep it saturated, so the queue time and throughput both eventually settle down to roughly constant values. Unlike the first-request issue, which can be solved trivially by blocking until the first response has been received, this is a real problem. If we can’t get data to Triton fast enough now, with a relatively simple setup, then we’ll get no benefit from the myriad ways Triton, and inference-as-a-service more broadly, make scaling up easy.

So why is this issue occurring? Well, think about what happened when we went from INFERENCE_SAMPLING_RATE = 0.25 to INFERENCE_SAMPLING_RATE = 4: we’re continuing to send KERNEL_LENGTH-long windows of data to the inference service with each request, but now we’re doing it 16 times as often, with largely redundant data! It’s this network I/O that’s bottlenecking our pipeline now.
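Some back-of-the-envelope arithmetic shows the scale of the problem, and how much sending only the new samples with each request would save:

```python
# back-of-the-envelope network load, float32 = 4 bytes per sample
NUM_IFOS, kernel_size, inference_stride = 2, 8192, 512
requests_per_second = 250 * 4  # INFERENCE_RATE * INFERENCE_SAMPLING_RATE

# sending the full window with every request...
full_window_bps = NUM_IFOS * kernel_size * 4 * requests_per_second
# ...vs. sending only the new stride's worth of samples
update_bps = NUM_IFOS * inference_stride * 4 * requests_per_second

print(full_window_bps // 10**6, update_bps // 10**6)  # 65 MB/s vs 4 MB/s
```

a 16x reduction in the data we need to push over the wire.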

To alleviate this, we’ll need to build a model on the server side which can cache the data we’ve already sent, using it along with each new update to build the windows we need and pass them along to our downstream model. hermes.quiver has built-in support for constructing such a model ensemble by adding a “snapshotter” model on the front end of the server which maintains the state of the most recent input snapshot.
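Conceptually, the snapshotter’s state update looks something like this (a pure-numpy illustration of the idea, not hermes’ actual implementation):

```python
import numpy as np

# a rolling buffer holding the most recent kernel's worth of data
kernel_size, stride, num_ifos = 8192, 512, 2
snapshot = np.zeros((num_ifos, kernel_size), dtype="float32")

def update_snapshot(snapshot, update):
    # drop the oldest stride of samples and append the newest,
    # yielding the full window to pass to the downstream model
    return np.concatenate([snapshot[:, update.shape[-1]:], update], axis=-1)

update = np.random.randn(num_ifos, stride).astype("float32")
snapshot = update_snapshot(snapshot, update)
print(snapshot.shape)  # (2, 8192): a full kernel, ready for inference
```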

# add a new meta-model to the repository that organizes
# graphs of existing models to pass outputs from one
# as inputs to the next
ensemble = repo.add("streaming-classifier", platform=qv.Platform.ENSEMBLE)

# insert a snapshotter model at the front of this ensemble
# whose output will be passed to the input of my-classifier
classifier = repo.models["my-classifier"]
ensemble.add_streaming_inputs(
    classifier.inputs["hoft"],
    stream_size=inference_stride,
    batch_size=1
)

# we know our first request takes ~12s, so make
# sure that the snapshotter will maintain a state
# for longer than this
snapshotter = repo.models["snapshotter"]
snapshotter.config.sequence_batching.max_sequence_idle_microseconds = int(25 * 10**6)
snapshotter.config.write()

# mark the output of our classifier as the output
# of the whole ensemble and then export a "version"
# of the ensemble (basically just writes its config)
ensemble.add_output(classifier.outputs["prob"])
ensemble.export_version(None)
'streaming-classifier/1/model.empty'

Now we can run inference on our streaming model and only send the updates we need for each request, rather than the entire window!

# keep our inference results from earlier to
# verify that this implementation comes out the same
nonstreaming_results = results

# quick addition to the callback to handle waiting
# for the first couple responses to come back
class BlockingCallback(Callback):
    def block(self, i):
        while self.y[i] == 0:
            time.sleep(1e-3)

# we need to do more inferences this time, since
# the snapshot state gets initialized to 0s, so the
# first KERNEL_LENGTH * INFERENCE_SAMPLING_RATE updates
# just function to fill the snapshot out
num_inferences = inference_data_size // inference_stride
callback = BlockingCallback(num_inferences)

metrics_file = metrics_dir / "streaming_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )

    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + inference_stride  # note the smaller slice
            kernel = hoft[:, start: stop]

            # provide some additional information to
            # the inference server to allow us to keep
            # track of multiple different streams
            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )

            if i < 3:
                callback.block(i)
            if i == 2:
                logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break

# ditch some of the initial inferences, which took
# place on a 0-initialized kernel with some updates
# placed at the end of it
results = results[int(KERNEL_LENGTH * INFERENCE_SAMPLING_RATE) - 1:]

# validate that this implementation produces the
# same results as the original
assert (results == nonstreaming_results).all()
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest expensive_milkshake_3248
17:01:50.692    Waiting for inference service to come online...                                                    
17:02:30.022    Service ready!                                                                                     
17:02:37.252    First 3 requests completed                                                                         
17:02:45.639    Inference complete!                                                                                
p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    dfs[model] = subdf

p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)

So this looks like a pretty stable configuration, which means we can up our inference rate until we start to hit a bottleneck. Let’s hop up to 400 and see what happens.

INFERENCE_RATE = 400
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)
rate_limiter = RateLimiter(max_calls=kernels_per_second / 20, period=0.05)
callback = BlockingCallback(num_inferences)

metrics_file = metrics_dir / "streaming_rate-400_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor",
        rate=4,
    )

    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + inference_stride
            kernel = hoft[:, start: stop]

            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )

            if i < 3:
                callback.block(i)
            if i == 2:
                logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest arid_frito_0916
17:02:48.609    Waiting for inference service to come online...                                                    
17:03:27.587    Service ready!                                                                                     
17:03:34.698    First 3 requests completed                                                                         
17:05:05.904    Inference complete!                                                                                
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    dfs[model] = subdf

p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)

So it looks like our queue latencies are still reasonably stable, but our throughput tanks early in the run. The likely culprit is that we’ve saturated the rate at which our client can generate requests: gRPC’s asynchronous requests are implemented with Python threads in a way that doesn’t lend itself to doing strenuous work in the main thread that generates them. We can improve throughput by sending longer updates at a lower rate, then using the caching model on the server to build batches of kernels to feed the downstream model.
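To see why batching relieves the client, consider how the request rate scales: the server must see the same number of kernels per second either way, but packing `batch_size` kernels into each update divides the number of gRPC calls the client has to issue. A small sketch of that arithmetic, with illustrative values (not from the hermes API):

```python
# Illustrative sketch: how batching reduces the client-side request rate.
# All values here are assumptions for demonstration.

def requests_per_second(inference_rate, inference_sampling_rate, batch_size=1):
    """Number of gRPC requests the client must issue each second.

    inference_rate: seconds of data processed per second (s' / s)
    inference_sampling_rate: kernels sampled per second of data
    batch_size: kernels packed into each request
    """
    kernels_per_second = inference_rate * inference_sampling_rate
    return kernels_per_second / batch_size

# e.g. at 400 s'/s with 64 kernels sampled per second of data:
unbatched = requests_per_second(400, 64)                 # one kernel per call
batched = requests_per_second(400, 64, batch_size=128)   # 128 kernels per call
print(unbatched, batched)
```

The GPU-side work is unchanged; only the number of round trips the single-threaded client has to orchestrate shrinks.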

Let’s export a new ensemble and associated snapshotter model that expects batched updates, then run inference with that.

batched_ensemble = repo.add("batched-streaming-classifier", platform=qv.Platform.ENSEMBLE)

# right now we can only handle batches small
# enough that the update isn't longer than
# the kernel itself, so we'll use the biggest
# batch we can with that constraint
batch_size = int(KERNEL_LENGTH * INFERENCE_SAMPLING_RATE)
classifier = repo.models["my-classifier"]
batched_ensemble.add_streaming_inputs(
    classifier.inputs["hoft"],
    stream_size=inference_stride,
    batch_size=batch_size,
    name="batched-snapshotter"
)

# we know our first request takes ~12s, so make
# sure that the snapshotter will maintain a state
# for longer than this
snapshotter = repo.models["batched-snapshotter"]
snapshotter.config.sequence_batching.max_sequence_idle_microseconds = int(25 * 10**6)
snapshotter.config.write()

# mark the output of our classifier as the output
# of the whole ensemble and then export a "version"
# of the ensemble (basically just writes its config)
batched_ensemble.add_output(classifier.outputs["prob"])
batched_ensemble.export_version(None)
'batched-streaming-classifier/1/model.empty'
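As a sanity check on the batch-size constraint mentioned above (an update can be no longer than one kernel), we can verify that the largest allowed batch makes the update exactly one kernel long. This sketch uses assumed values for `KERNEL_LENGTH`, `SAMPLE_RATE`, and `INFERENCE_SAMPLING_RATE`, not necessarily those defined earlier in the tutorial:

```python
# Sketch: verify the largest-batch constraint, i.e. that a full update
# (batch_size strides of new data) is no longer than one kernel.
# The constants here are assumptions for illustration.
KERNEL_LENGTH = 1             # seconds per kernel (assumed)
SAMPLE_RATE = 4096            # samples per second of data (assumed)
INFERENCE_SAMPLING_RATE = 64  # kernels per second of data (assumed)

inference_stride = int(SAMPLE_RATE / INFERENCE_SAMPLING_RATE)
batch_size = int(KERNEL_LENGTH * INFERENCE_SAMPLING_RATE)

update_size = batch_size * inference_stride
kernel_size = int(KERNEL_LENGTH * SAMPLE_RATE)

# with this choice of batch_size the constraint holds with equality:
# each update carries exactly one kernel's worth of new samples
assert update_size <= kernel_size
```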
# make a new callback that's better equipped to
# slice out a batch of responses
class BatchedCallback(BlockingCallback):
    def __init__(self, num_inferences, batch_size):
        # initialize the parent so its blocking machinery still works
        super().__init__(num_inferences)

        # each response now carries batch_size predictions,
        # so the output timeseries is batch_size times longer
        self.y = np.zeros((num_inferences * batch_size,))
        self.batch_size = batch_size

    def __call__(self, response, request_id, sequence_id):
        start = request_id * self.batch_size
        self.y[start: start + len(response)] = response[:, 0]
        if (start + len(response)) >= len(self.y):
            return self.y
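To see how the slicing arithmetic assembles batched responses into a single timeseries, here's a standalone mock of the same bookkeeping, decoupled from `BlockingCallback` (the shapes and values are made up for illustration):

```python
import numpy as np

# Standalone mock of the batched-response bookkeeping: each "response"
# carries batch_size predictions with shape (batch_size, 1), and the
# request_id tells us where in the output timeseries they belong.
num_inferences, batch_size = 4, 8
y = np.zeros(num_inferences * batch_size)

for request_id in range(num_inferences):
    # fake predictions: fill each batch with its request_id
    response = np.full((batch_size, 1), float(request_id))
    start = request_id * batch_size
    y[start: start + len(response)] = response[:, 0]

# the batches land in order, tiling the full output array
assert (y[:batch_size] == 0).all() and (y[-batch_size:] == 3).all()
```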


num_kernels = inference_data_size // inference_stride
num_inferences = num_kernels // batch_size
callback = BatchedCallback(num_inferences, batch_size)

batches_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE / batch_size)
rate_limiter = RateLimiter(max_calls=batches_per_second / 20, period=0.05)

metrics_file = metrics_dir / "batched-streaming_single-model.csv"

with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="batched-streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="batched-streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )

    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride * batch_size
            stop = (i + 1) * inference_stride * batch_size
            kernel = hoft[:, start: stop]

            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )

            if i < 3:
                callback.block(i)
            if i == 2:
                logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest wobbly_hippo_8793
17:05:09.062    Waiting for inference service to come online...                                                    
17:05:52.971    Service ready!                                                                                     
17:05:59.821    First 3 requests completed                                                                         
17:06:05.004    Inference complete!                                                                                
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    subdf["Throughput (s' / s)"] *= batch_size
    dfs[model] = subdf

p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)

Ok, so we’re back to a stable system. Let’s try pushing this a bit further, using a longer segment of data so we can get a decent estimate of our metrics.

INFERENCE_DATA_LENGTH = 32768
INFERENCE_RATE = 4096

inference_data_size = INFERENCE_DATA_LENGTH * SAMPLE_RATE
hoft = np.random.randn(NUM_IFOS, inference_data_size).astype("float32")

num_kernels = inference_data_size // inference_stride
num_inferences = num_kernels // batch_size
callback = BatchedCallback(num_inferences, batch_size)

batches_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE / batch_size)
rate_limiter = RateLimiter(max_calls=batches_per_second / 20, period=0.05)

metrics_file = metrics_dir / "batched-streaming_rate-4096_single-model.csv"

with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="batched-streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="batched-streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )

    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride * batch_size
            stop = (i + 1) * inference_stride * batch_size
            kernel = hoft[:, start: stop]

            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )

            if i < 3:
                callback.block(i)
            if i == 2:
                logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest joyous_signal_7729
17:06:14.401    Waiting for inference service to come online...                                                    
17:06:57.836    Service ready!                                                                                     
17:07:04.844    First 3 requests completed                                                                         
17:07:59.569    Inference complete!                                                                                
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    subdf["Throughput (s' / s)"] *= batch_size
    dfs[model] = subdf

p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)