Overview
In this example, we’ll begin by building a neural network, then export it from memory to disk in a format that Triton can use to serve it for inference. In practice we’d train this network on some data between these steps, but since we’re more interested here in the inference side, we’ll pretend this training has been done elsewhere. For the sake of simplicity, we’ll build a 1D convolutional network with a single output (which might be used, e.g., for binary classification).
Once our model has been properly exported, we’ll spin up a Triton inference service which will load the model and expose it for inference via gRPC requests. We’ll then build some dummy inference data and iterate through it to build requests to send to our inference service, aggregating its responses into a timeseries of network outputs.
As we’ll see, this vanilla implementation needs some adjustments to achieve higher data throughput (which we’ll measure in seconds of data per second, or s’ / s) at higher inference sampling rates, i.e. the rate at which we sample windows from our timeseries. Those adjustments shift the bottleneck to the GPU so that we get the most compute utilization possible out of it. hermes helps to make those adjustments simple, and to achieve greater scale once you’ve made them.
We’ll start with our imports. From hermes, we’ll be using
- hermes.quiver to handle exporting our model to a format usable by Triton
- hermes.aeriel.serve to spin up an inference service locally using Python APIs
- hermes.aeriel.client to make requests to that inference service
- hermes.stillwater.ServerMonitor to keep track of server-side inference metrics
import time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from ratelimiter import RateLimiter
# our hermes imports
from hermes import quiver as qv
from hermes.aeriel.client import InferenceClient
from hermes.aeriel.serve import serve
from hermes.stillwater import ServerMonitor
# couple cheap local utilities
from src import utils, plotting
logger = utils.get_logger()
Now let’s establish some parameters, both for the model we’d like to build and for our inference-time deployment.
# model parameters
NUM_IFOS = 2 # number of interferometers analyzed by our model
SAMPLE_RATE = 2048 # rate at which input data to the model is sampled
KERNEL_LENGTH = 4 # length of the input to the model in seconds
# inference parameters
INFERENCE_DATA_LENGTH = 2048 # amount of data to analyze at inference time
INFERENCE_SAMPLING_RATE = 0.25 # rate at which we'll sample input windows from the inference data
INFERENCE_RATE = 250 # seconds of data we'll try to analyze per second
# convert some of these into more useful units for slicing purposes
kernel_size = int(SAMPLE_RATE * KERNEL_LENGTH)
inference_stride = int(SAMPLE_RATE / INFERENCE_SAMPLING_RATE)
inference_data_size = int(SAMPLE_RATE * INFERENCE_DATA_LENGTH)
num_inferences = (inference_data_size - kernel_size) // inference_stride + 1
# limit the number of requests we make per second
# so that we don't overload the network or server
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)
rate_limiter = RateLimiter(max_calls=kernels_per_second, period=1)
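To make the slicing arithmetic above concrete, here is a small standalone sketch that recomputes the derived quantities for the two inference sampling rates used in this example (0.25 Hz now, 4 Hz later on). The helper name `derive_sizes` is hypothetical and not part of hermes; the constants are copied from the parameters above.

```python
def derive_sizes(sample_rate, kernel_length, data_length, sampling_rate, inference_rate):
    """Recompute the derived slicing quantities from the raw parameters."""
    kernel_size = int(sample_rate * kernel_length)      # samples per input window
    stride = int(sample_rate / sampling_rate)           # samples between window starts
    data_size = int(sample_rate * data_length)          # total samples to analyze
    num_inferences = (data_size - kernel_size) // stride + 1
    kernels_per_second = int(inference_rate * sampling_rate)
    return {
        "kernel_size": kernel_size,
        "stride": stride,
        "num_inferences": num_inferences,
        "kernels_per_second": kernels_per_second,
    }


low = derive_sizes(2048, 4, 2048, 0.25, 250)
high = derive_sizes(2048, 4, 2048, 4, 250)
print(low)
print(high)
```

At 0.25 Hz the stride equals the kernel size, so consecutive windows do not overlap. At 4 Hz the stride shrinks to 512 samples and the number of requests grows by roughly a factor of 16.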
Now let’s build our extremely simple network
class GlobalAvgPool(torch.nn.Module):
    def forward(self, x):
        return x.mean(axis=-1)

nn = torch.nn.Sequential(
    torch.nn.Conv1d(NUM_IFOS, 8, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(8, 32, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(32, 64, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(64, 128, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv1d(128, 256, kernel_size=7, stride=2),
    torch.nn.ReLU(),
    GlobalAvgPool(),
    torch.nn.Linear(256, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1)
)
# INSERT TRAINING CODE HERE
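As a sanity check on the architecture above, the following sketch traces how the time dimension shrinks through the five unpadded convolutions (kernel size 7, stride 2) using the standard 1D convolution output-length formula. It is plain Python and does not require torch; the helper name `conv1d_output_length` is an assumption made for this illustration.

```python
def conv1d_output_length(length, kernel_size=7, stride=2, padding=0, dilation=1):
    """Output length of a 1D convolution (standard formula, floor division)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


length = 2048 * 4  # kernel_size: SAMPLE_RATE * KERNEL_LENGTH samples
lengths = [length]
for _ in range(5):  # five Conv1d layers in the network above
    length = conv1d_output_length(length)
    lengths.append(length)

print(lengths)
```

The global average pool then collapses whatever length remains, leaving one value per channel (256 of them) for the linear layers, which is why the first linear layer takes 256 inputs regardless of the window length.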
Ok, now the set-up work is done. We have a “trained” neural network, and we’re ready to export it for inference. Enter hermes. One of the key concepts in as-a-service inference is the idea of a model repository, a local (or cloud-based) directory with a prescribed structure that hosts all the versions of all the models to be served for inference.
Maintaining this prescribed structure, as well as the configs that Triton needs to be able to map to named inputs and outputs, can be onerous and non-trivial, so hermes.quiver was built to take the headache out of building and maintaining these repositories. In this next step, we’ll build a model repository (clearing it beforehand in case you run this notebook multiple times), add a new entry to it for the network that we just built, then export the current version of this network to that repository as an ONNX binary which Triton can load and serve for inference.
# let's make sure we're starting with a fresh repo
repo_path = "model-repo"
utils.clear_repo(repo_path)
# initialize a blank model repository
repo = qv.ModelRepository(repo_path)
assert len(repo.models) == 0 # this attribute will get updated as we add models
# create a new entry in the repo for our model
model = repo.add("my-classifier", platform=qv.Platform.ONNX)
assert len(repo.models) == 1
assert model == repo.models["my-classifier"]
# now export our current version of the network to this entry.
# Since we haven't exported any versions of this model yet,
# Triton needs to know what names to give the inputs and
# outputs and what shapes to expect, so we have to specify
# them explicitly this first time.
# Note that -1 indicates variable length batch dimension.
model.export_version(
    nn,
    input_shapes={"hoft": (-1, NUM_IFOS, kernel_size)},
    output_names=["prob"]
)
'my-classifier/1/model.onnx'
Note that our model is associated with a Config object that describes the metadata Triton requires
model.config
name: "my-classifier"
platform: "onnxruntime_onnx"
input {
name: "hoft"
data_type: TYPE_FP32
dims: -1
dims: 2
dims: 8192
}
output {
name: "prob"
data_type: TYPE_FP32
dims: -1
dims: 1
}
and that each model can be associated with multiple different versions corresponding to different weight values, or even wholesale different architectures. The only thing that matters is that the network represents the same input-to-output mapping:
model.versions
[1]
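For orientation, the on-disk structure that Triton expects from a model repository can be sketched with the standard library alone: one directory per model holding a `config.pbtxt`, plus one numbered subdirectory per version containing the model file. The snippet below only builds empty placeholder files in a temporary directory; it is not how hermes.quiver writes the repository, the helper `make_layout` is hypothetical, and the second version is included purely to illustrate multiple versions (only version 1 exists in this example).

```python
import tempfile
from pathlib import Path


def make_layout(root, model_name, versions, filename="model.onnx"):
    """Create an empty Triton-style model directory and list its files."""
    root = Path(root)
    model_dir = root / model_name
    model_dir.mkdir(parents=True)
    # placeholder config; a real one also describes inputs and outputs
    (model_dir / "config.pbtxt").write_text(f'name: "{model_name}"\n')
    for version in versions:
        version_dir = model_dir / str(version)
        version_dir.mkdir()
        (version_dir / filename).touch()  # empty stand-in for the exported model
    files = (p for p in root.rglob("*") if p.is_file())
    return sorted(p.relative_to(root).as_posix() for p in files)


with tempfile.TemporaryDirectory() as tmp:
    layout = make_layout(tmp, "my-classifier", versions=[1, 2])

for path in layout:
    print(path)
```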
And with that, our model is ready to be served for inference! In the code below, we’ll use hermes.aeriel.serve to spin up a Singularity container on a single GPU (index 0) which will run a Triton inference service in the background. Once the serve context exits, the service and the container running it will both be spun down.
Once the server is up and running, we’ll use hermes.aeriel.client to establish a client connection to it, then iterate through our dummy data and make requests using this connection. The requests are made asynchronously and the responses parsed in a callback thread. The parsed responses are made available in the main thread through a queue which can be succinctly accessed by InferenceClient.get(), which will return None if there are no responses to be returned.
Note that below, we’ll do our inference in batch sizes of 1. In principle, we could get better throughput by increasing that batch size (at the cost of some latency), but as we’ll see momentarily there are more pressing issues we need to address first.
# Start by spinning up an inference service.
# Note that this will print the singularity command
# used to start the service. This is unfortunately an
# unavoidable bug in singularity right now (the only
# alternative is to receive an unnecessary warning
# instead), but it's probably good to note what is
# really happening under the hood anyway.
# The `instance` returned by the context is an object
# representing the running singularity container instance.
with serve("model-repo", gpus=[0]) as instance:
    # Do our data generation in parallel while the server
    # spins up. This will obviously be pretty fast, but
    # for more complicated data generation steps this can
    # be a good way to parallelize your efforts.
    hoft = np.random.randn(NUM_IFOS, inference_data_size).astype("float32")

    # now wait until the inference service is online and
    # ready to receive requests before we attempt to connect
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    # establish a client connection to the inference service
    # and infer the names and shapes of the inputs it expects
    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1  # can use -1 to imply latest version
    )

    # client context establishes a streaming connection to
    # the inference service. Since we're not yet streaming
    # updates but making requests individually, we don't
    # technically need this, but it's a good habit to get into.
    with client:
        # now iterate through our inference timeseries at the
        # prescribed stride and send these inputs to the server
        # for inference.
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size

            # add a dummy dimension for the batch
            kernel = hoft[:, start: stop][None]
            with rate_limiter:
                client.infer(kernel)

        # now that we've submitted all our inference requests,
        # start pulling them from the output queue as they
        # become available.
        results = []
        while len(results) < num_inferences:
            response = client.get()
            if response is not None:
                y, request_id, sequence_id = response
                results.append(y[:, 0])
        logger.info("Inference complete!")

# concatenate all the responses into a single timeseries
results = np.concatenate(results)
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest pusheena_despacito_2246
16:58:54.547 Waiting for inference service to come online...
16:59:41.155 Service ready!
16:59:49.216 Inference complete!
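The request pattern used above (submit requests without waiting, let a background thread handle responses, then drain a queue with a `get()` that returns None when nothing is ready) can be sketched without Triton or hermes. The `ToyClient` class below is a hypothetical stand-in written for this illustration; it is not the `InferenceClient` API and it runs the "model" in a local thread instead of on a server.

```python
import queue
import threading


class ToyClient:
    """Hypothetical stand-in for an inference client; NOT the hermes API."""

    def __init__(self, model):
        self._model = model
        self._requests = queue.Queue()
        self._responses = queue.Queue()
        self._worker = threading.Thread(target=self._serve, daemon=True)
        self._worker.start()

    def _serve(self):
        # plays the role of the server plus the response-parsing thread
        while True:
            item = self._requests.get()
            if item is None:  # sentinel: shut down
                break
            request_id, x = item
            self._responses.put((self._model(x), request_id))

    def infer(self, x, request_id):
        # returns immediately; the work happens in the background thread
        self._requests.put((request_id, x))

    def get(self):
        # non-blocking: None means "no response available yet"
        try:
            return self._responses.get_nowait()
        except queue.Empty:
            return None

    def close(self):
        self._requests.put(None)
        self._worker.join()


client = ToyClient(model=lambda x: 2 * x)
num_requests = 5
for i in range(num_requests):
    client.infer(float(i), request_id=i)

# poll until every response has been collected
results = {}
while len(results) < num_requests:
    response = client.get()
    if response is not None:
        y, request_id = response
        results[request_id] = y
client.close()
print(results)
```

Keeping the submission loop free of blocking calls is what lets the main thread keep sending requests at the target rate while responses trickle back.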
In principle, that does it: we served a model, ran data through it, and from looking at the timestamps on the logs it looks like we roughly hit our inference rate target. You now have everything you need to run inference-as-a-service with hermes. But there are some considerations you might think about that could help improve both network performance and throughput. For some inspiration, let’s take a look at what our timeseries of network responses looks like:
p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)
So as we should have expected, it’s just a timeseries of more or less random data. What is worth noting about this timeseries, however, is the rate at which it is sampled: 0.25 Hz means that for shorter-duration events like binary black hole mergers, we only get to make one prediction on each event. Surely there might be some benefit to taking predictions from multiple overlapping windows containing the same event and aggregating them somehow. Let’s increase our inference sampling rate to, say, 4 Hz and see how this impacts our throughput.
For this next round of inference, I’m going to add a little more complication up front in exchange for slightly more elegance at inference time. Rather than iterating through responses after all our inference requests have been submitted, I’ll set up a callback up front that aggregates our responses into an array in real time in the callback thread, then returns the array once completed.
class Callback:
    def __init__(self, num_inferences):
        self.y = np.zeros((num_inferences,))

    def __call__(self, response, request_id, sequence_id):
        self.y[request_id] = response[0, 0]
        if (request_id + 1) == len(self.y):
            return self.y

# reset some of our parameters with a new inference sampling rate
INFERENCE_SAMPLING_RATE = 4
inference_stride = int(SAMPLE_RATE / INFERENCE_SAMPLING_RATE)
num_inferences = (inference_data_size - kernel_size) // inference_stride + 1
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)
rate_limiter = RateLimiter(max_calls=kernels_per_second / 20, period=0.05)
callback = Callback(num_inferences)
# from here things will look more or less the same
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    # instantiate client with custom callback
    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1,
        callback=callback
    )
    with client:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size
            kernel = hoft[:, start: stop][None]

            # pass explicit request ids this time
            # for the callback to use
            with rate_limiter:
                client.infer(kernel, request_id=i)

        # now wait until the callback returns its
        # filled out array to the client's queue
        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest chunky_hippo_6797
16:59:51.858 Waiting for inference service to come online...
17:00:35.590 Service ready!
17:00:51.840 Inference complete!
And now let’s take a look at how this timeseries looks
p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)
So things work with higher frequency inference, but a quick look at our logs indicates that we’re now falling well short of our intended inference rate target (in the logs above, the time delta is ~16 s for 2048 s of data, which translates to an inference rate of ~126 seconds of data / s against a target of 250; exact numbers vary from run to run). Why is that? Is our network capable of handling the desired number of requests per second, or is there something else going on?
For diagnosing these sorts of questions, Triton includes a secondary service which makes some inference metrics available (served at port 8002 by default). Clients can query this service for per-model information like cumulative time spent queueing and executing requests, as well as cumulative inference counts. The hermes.stillwater library has a ServerMonitor class which, in a separate process, queries this service and organizes the returned metrics from potentially many servers and writes them to a local log file.
Let’s run the same snippet from above with a monitor in place and take a look at how queuing latency evolves over time. If requests spend longer and longer queuing over time, then our network is acting as a bottleneck and we need to scale it up. Otherwise, something else is going wrong.
# set up a new directory just for our metrics
metrics_dir = Path("metrics")
metrics_dir.mkdir(exist_ok=True)
metrics_file = metrics_dir / "non-streaming_single-model.csv"
callback = Callback(num_inferences)
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="my-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="my-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor",
        rate=4
    )
    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + kernel_size
            kernel = hoft[:, start: stop][None]

            # for various reasons, the rate limiter won't work
            # here, so we'll be generous and insert a sleep
            # for half as long as the actual rate would be
            client.infer(kernel, request_id=i)
            time.sleep(0.5 / INFERENCE_SAMPLING_RATE / INFERENCE_RATE)

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest grated_leg_5296
17:00:56.199 Waiting for inference service to come online...
17:01:37.029 Service ready!
17:01:47.234 Inference complete!
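Before looking at the monitor's output, it may help to see how per-interval numbers can be derived from the metrics service described above. Triton reports its inference statistics as cumulative counters in Prometheus text format, so a monitor has to difference consecutive snapshots. The sketch below parses two such snapshots and subtracts them. It is an illustration only and not the `ServerMonitor` implementation; the two snapshot strings are constructed for this sketch by accumulating the count and queue values from the first two rows of the metrics table in this section, and the exact label set on a real server may differ.

```python
import re

# matches lines like: metric_name{label="value",...} 123
LINE = re.compile(r'^(?P<name>\w+)\{(?P<labels>[^}]*)\}\s+(?P<value>\S+)$')


def parse_metrics(text, model):
    """Map metric name -> value for one model from Prometheus-style text."""
    values = {}
    for line in text.splitlines():
        match = LINE.match(line.strip())
        if match is None:
            continue  # blank lines and "# HELP" / "# TYPE" comments
        if f'model="{model}"' not in match.group("labels"):
            continue
        values[match.group("name")] = float(match.group("value"))
    return values


# illustrative snapshots (constructed, not captured from a server)
first = """
# TYPE nv_inference_count counter
nv_inference_count{model="my-classifier",version="1"} 87
nv_inference_queue_duration_us{model="my-classifier",version="1"} 587301553
"""
second = """
# TYPE nv_inference_count counter
nv_inference_count{model="my-classifier",version="1"} 319
nv_inference_queue_duration_us{model="my-classifier",version="1"} 2144977518
"""

before = parse_metrics(first, "my-classifier")
after = parse_metrics(second, "my-classifier")

# cumulative counters -> values for the interval between the two pings
delta = {name: after[name] - before[name] for name in after}
avg_queue_us = delta["nv_inference_queue_duration_us"] / delta["nv_inference_count"]
print(delta, avg_queue_us)
```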
Now let’s load in the CSV that our monitor produced and take a look at the data it captured.
df = pd.read_csv(metrics_file)
df
| | timestamp | ip | model | count | queue | compute_input | compute_infer | compute_output | request |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.663891e+09 | localhost | my-classifier | 87 | 587301553 | 9529 | 6885885 | 784 | 594202569 |
| 1 | 1.663891e+09 | localhost | my-classifier | 232 | 1557675965 | 12471 | 73676 | 1842 | 1557775318 |
| 2 | 1.663891e+09 | localhost | my-classifier | 331 | 2145661313 | 22275 | 7158593 | 2987 | 2152862687 |
| 3 | 1.663891e+09 | localhost | my-classifier | 465 | 2857294202 | 24292 | 146878 | 3960 | 2857492663 |
| 4 | 1.663891e+09 | localhost | my-classifier | 445 | 2551900100 | 23089 | 148376 | 3760 | 2552098128 |
| 5 | 1.663891e+09 | localhost | my-classifier | 459 | 2444907900 | 24595 | 146049 | 3668 | 2445105272 |
| 6 | 1.663891e+09 | localhost | my-classifier | 457 | 2241774583 | 24063 | 146711 | 4024 | 2241971869 |
| 7 | 1.663891e+09 | localhost | my-classifier | 392 | 1770680539 | 22459 | 148572 | 3507 | 1770877783 |
| 8 | 1.663891e+09 | localhost | my-classifier | 417 | 1734183397 | 22599 | 149888 | 3640 | 1734381239 |
| 9 | 1.663891e+09 | localhost | my-classifier | 441 | 1625196357 | 23455 | 148051 | 3652 | 1625394260 |
| 10 | 1.663891e+09 | localhost | my-classifier | 218 | 728727283 | 20618 | 153130 | 3996 | 728925976 |
| 11 | 1.663891e+09 | localhost | my-classifier | 441 | 1351631666 | 24599 | 145017 | 3758 | 1351827770 |
| 12 | 1.663891e+09 | localhost | my-classifier | 472 | 1249338573 | 23598 | 148549 | 3708 | 1249536293 |
| 13 | 1.663891e+09 | localhost | my-classifier | 293 | 674164299 | 39172 | 133387 | 3226 | 674361466 |
| 14 | 1.663891e+09 | localhost | my-classifier | 149 | 322752920 | 42381 | 150264 | 3191 | 322972130 |
| 15 | 1.663891e+09 | localhost | my-classifier | 386 | 766527339 | 44580 | 133665 | 3243 | 766730229 |
| 16 | 1.663891e+09 | localhost | my-classifier | 415 | 674953706 | 30304 | 137799 | 3660 | 675148432 |
| 17 | 1.663891e+09 | localhost | my-classifier | 457 | 545564093 | 23139 | 147243 | 3520 | 545758828 |
| 18 | 1.663891e+09 | localhost | my-classifier | 469 | 290677762 | 22557 | 148155 | 3163 | 290873176 |
| 19 | 1.663891e+09 | localhost | my-classifier | 302 | 37216007 | 14891 | 98494 | 2104 | 37345551 |
| 20 | 1.663891e+09 | localhost | my-classifier | 66 | 3553 | 4849 | 31794 | 816 | 45876 |
| 21 | 1.663891e+09 | localhost | my-classifier | 67 | 3477 | 3777 | 25752 | 725 | 37867 |
| 22 | 1.663891e+09 | localhost | my-classifier | 64 | 7363 | 4351 | 24825 | 657 | 41130 |
| 23 | 1.663891e+09 | localhost | my-classifier | 74 | 3021 | 4662 | 36788 | 714 | 49819 |
| 24 | 1.663891e+09 | localhost | my-classifier | 76 | 3660 | 4521 | 29046 | 674 | 42843 |
| 25 | 1.663891e+09 | localhost | my-classifier | 70 | 4036 | 4467 | 31954 | 700 | 45591 |
| 26 | 1.663891e+09 | localhost | my-classifier | 47 | 8670 | 4569 | 32898 | 688 | 51345 |
| 27 | 1.663891e+09 | localhost | my-classifier | 82 | 94784 | 4455 | 29302 | 748 | 133859 |
| 28 | 1.663891e+09 | localhost | my-classifier | 82 | 4372 | 4010 | 27464 | 611 | 40991 |
| 29 | 1.663891e+09 | localhost | my-classifier | 55 | 17061 | 5924 | 48886 | 935 | 79037 |
| 30 | 1.663891e+09 | localhost | my-classifier | 33 | 27969 | 4931 | 37830 | 1056 | 80273 |
| 31 | 1.663891e+09 | localhost | my-classifier | 75 | 6033 | 5107 | 33550 | 812 | 51299 |
The queue, compute_input, compute_infer, and compute_output columns all represent different steps in the inference compute of a single request. The values in each column represent the cumulative microseconds spent on each step over all of the requests computed between each ping to the metrics service, the number of which is indicated by the count column. Knowing this, let’s reframe some of the info in these columns in a more useful fashion for our purposes.
# we'll be doing this a lot, so let's record it as a function
def cleanup_df(df, t0):
    df["Time since start (s)"] = df["timestamp"] - t0
    df["Average queue time (us)"] = df["queue"] / df["count"]

    # count all inference steps as a single metric
    # of inference latency
    infer_time = df[[f"compute_{i}" for i in ["input", "infer", "output"]]].sum(axis=1)
    df["Average infer time (us)"] = infer_time / df["count"]

    # use the number of inferences completed in an interval
    # along with the inference sampling rate to put throughput
    # in units of data seconds per second
    df["Throughput (s' / s)"] = df["count"] / df["timestamp"].diff() / INFERENCE_SAMPLING_RATE
    return df[[
        "Time since start (s)",
        "Throughput (s' / s)",
        "Average queue time (us)",
        "Average infer time (us)"
    ]]

df = cleanup_df(df, df.timestamp.min())
df
| | Time since start (s) | Throughput (s' / s) | Average queue time (us) | Average infer time (us) |
|---|---|---|---|---|
| 0 | 0.000000 | NaN | 6.750593e+06 | 79266.643678 |
| 1 | 0.101490 | 571.482073 | 6.714121e+06 | 379.262931 |
| 2 | 0.203203 | 813.569839 | 6.482360e+06 | 21703.489426 |
| 3 | 0.304591 | 1146.577559 | 6.144719e+06 | 376.623656 |
| 4 | 0.406262 | 1094.220999 | 5.734607e+06 | 393.764045 |
| 5 | 0.507635 | 1131.956029 | 5.326597e+06 | 379.764706 |
| 6 | 0.609230 | 1124.566687 | 4.905415e+06 | 382.490153 |
| 7 | 0.711257 | 960.526137 | 4.517042e+06 | 445.250000 |
| 8 | 0.812940 | 1025.250927 | 4.158713e+06 | 422.366906 |
| 9 | 0.914786 | 1082.516957 | 3.685253e+06 | 397.183673 |
| 10 | 1.017116 | 532.588316 | 3.342786e+06 | 815.339450 |
| 11 | 1.118904 | 1083.138178 | 3.064924e+06 | 393.138322 |
| 12 | 1.220450 | 1162.026193 | 2.646904e+06 | 372.574153 |
| 13 | 1.323272 | 712.400683 | 2.300902e+06 | 599.948805 |
| 14 | 1.441036 | 316.308677 | 2.166127e+06 | 1314.335570 |
| 15 | 1.542648 | 949.692359 | 1.985822e+06 | 470.176166 |
| 16 | 1.644157 | 1022.078312 | 1.626394e+06 | 413.886747 |
| 17 | 1.745598 | 1126.274128 | 1.193795e+06 | 380.529540 |
| 18 | 1.847353 | 1152.278619 | 6.197820e+05 | 370.735608 |
| 19 | 1.959072 | 675.803330 | 1.232318e+05 | 382.413907 |
| 20 | 2.060756 | 162.267268 | 5.383333e+01 | 567.560606 |
| 21 | 2.162252 | 165.030002 | 5.189552e+01 | 451.552239 |
| 22 | 2.263905 | 157.399162 | 1.150469e+02 | 466.140625 |
| 23 | 2.366733 | 179.912041 | 4.082432e+01 | 569.783784 |
| 24 | 2.468355 | 186.966758 | 4.815789e+01 | 450.539474 |
| 25 | 2.572194 | 168.530257 | 5.765714e+01 | 530.300000 |
| 26 | 2.674086 | 115.317916 | 1.844681e+02 | 811.808511 |
| 27 | 2.775749 | 201.646866 | 1.155902e+03 | 420.792683 |
| 28 | 2.877061 | 202.346339 | 5.331707e+01 | 391.280488 |
| 29 | 2.982984 | 129.811063 | 3.102000e+02 | 1013.545455 |
| 30 | 3.084848 | 80.990261 | 8.475455e+02 | 1327.787879 |
| 31 | 3.186673 | 184.139457 | 8.044000e+01 | 526.253333 |
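As a quick check of how these columns relate to the raw metrics, the sketch below redoes the arithmetic of `cleanup_df` by hand for row 1 of the tables above. The input values are copied from the raw table, and the interval is the (rounded) "Time since start" difference between rows 0 and 1, so the throughput only agrees with the table to a couple of decimal places.

```python
# raw values for row 1 of the metrics table above
row = {
    "count": 232,
    "queue": 1557675965,
    "compute_input": 12471,
    "compute_infer": 73676,
    "compute_output": 1842,
}
interval = 0.101490 - 0.0  # seconds elapsed since row 0 (rounded)
inference_sampling_rate = 4  # Hz, as set for this run

# cumulative microseconds divided by the number of requests
avg_queue_us = row["queue"] / row["count"]
avg_infer_us = (
    row["compute_input"] + row["compute_infer"] + row["compute_output"]
) / row["count"]

# inferences per second, converted to seconds of data per second
throughput = row["count"] / interval / inference_sampling_rate

print(avg_queue_us, avg_infer_us, throughput)
```

Dividing by the inference sampling rate is what converts "kernels per second" into "seconds of data per second": at 4 Hz, four inferences advance the timeseries by one second.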
Now let’s plot the queue time and throughput as a function of time, and consider what this data tells us:
p = plotting.plot_inference_metrics_vs_time(my_classifier=df)
plotting.show(p)
Counterintuitively, our queue time starts off really high, then goes down over time! Moreover, the length of the x-axis, around 3-4 seconds, doesn’t match up with the 10 second time delta we see between our logs above. What’s going on here?
It turns out that the first couple of inferences can often take substantially longer than the rest, as Triton tries to optimize the compute kernels used to actually perform inference. So while the first request is taking several seconds to process, an enormous queue builds up while we inundate the server with follow-ups. Meanwhile, the ServerMonitor doesn’t start saving metrics until the server-side metrics service indicates that at least one inference has completed, which explains why we see a much shorter observation window.
Once that first request is processed, our throughput skyrockets while the network tears through data as fast as it can (in fact this gives a decent estimate of the network’s individual throughput cap: around 1200 s’/s). However, we aren’t supplying new data to the network fast enough to keep it saturated, so the queue time and throughput both eventually settle down to some roughly constant values. Unlike the first request issue, which can be solved trivially by introducing a block until the first response has been received, this is a real problem. If we can’t get data to Triton fast enough now, with a relatively simple set up, then we’ll get no benefit from the myriad ways Triton, and inference-as-a-service more broadly, make scaling up easy.
So why is this issue occurring? Well, think about what happened when we went from INFERENCE_SAMPLING_RATE = 0.25 to INFERENCE_SAMPLING_RATE = 4: we’re continuing to send KERNEL_LENGTH-long windows of data to the inference service with each request, but now we’re doing it 16 times as often per second, with largely redundant data! It’s this network I/O that’s bottlenecking our pipeline now.
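The size of this effect is easy to estimate. The sketch below computes the payload bandwidth needed to hit the 250 s’ / s target when every request carries a full window, and compares it with the bandwidth of the genuinely new samples. It counts float32 payload bytes only and ignores protocol overhead; the function names are hypothetical helpers for this illustration.

```python
BYTES_PER_SAMPLE = 4  # float32


def request_bandwidth(sampling_rate, inference_rate=250, num_ifos=2,
                      sample_rate=2048, kernel_length=4):
    """Payload bytes per second when each request carries a full kernel."""
    kernel_bytes = num_ifos * sample_rate * kernel_length * BYTES_PER_SAMPLE
    requests_per_second = inference_rate * sampling_rate
    return kernel_bytes * requests_per_second


def new_data_bandwidth(inference_rate=250, num_ifos=2, sample_rate=2048):
    """Payload bytes per second of new samples, independent of sampling rate."""
    return inference_rate * num_ifos * sample_rate * BYTES_PER_SAMPLE


slow = request_bandwidth(0.25)  # non-overlapping windows
fast = request_bandwidth(4)     # heavily overlapping windows
unique = new_data_bandwidth()

print(f"0.25 Hz: {slow / 1e6:.1f} MB/s")
print(f"4 Hz:    {fast / 1e6:.1f} MB/s")
print(f"new data only: {unique / 1e6:.1f} MB/s")
print(f"redundancy at 4 Hz: {fast / unique:.0f}x")
```

At 0.25 Hz the windows do not overlap, so nothing is sent twice. At 4 Hz each sample is sent 16 times, which is exactly the redundancy a server-side cache can remove.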
To alleviate this, we’ll need to build a model on the server side which can cache data we’ve already sent, and use it along with updates of new data to build the windows we need and pass them along to our downstream model. hermes.quiver has built-in support for constructing such a model ensemble by adding in a “snapshotter” model on the front end of the server which maintains the state of the most recent input snapshot.
# add a new meta-model to the repository that organizes
# graphs of existing models to pass outputs from one
# as inputs to the next
ensemble = repo.add("streaming-classifier", platform=qv.Platform.ENSEMBLE)
# insert a snapshotter model at the front of this ensemble
# whose output will be passed to the input of my-classifier
classifier = repo.models["my-classifier"]
ensemble.add_streaming_inputs(
    classifier.inputs["hoft"],
    stream_size=inference_stride,
    batch_size=1
)
# we know our first request takes ~12s, so make
# sure that the snapshotter will maintain a state
# for longer than this
snapshotter = repo.models["snapshotter"]
snapshotter.config.sequence_batching.max_sequence_idle_microseconds = int(25 * 10**6)
snapshotter.config.write()
# mark the output of our classifier as the output
# of the whole ensemble and then export a "version"
# of the ensemble (basically just writes its config)
ensemble.add_output(classifier.outputs["prob"])
ensemble.export_version(None)
'streaming-classifier/1/model.empty'
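Conceptually, the snapshotter keeps the most recent window as state: each update drops the oldest samples and appends the new ones. The toy version below illustrates that idea in plain Python on a single channel with tiny sizes. It is a hypothetical sketch, not the model that hermes.quiver exports, and it also shows why the first few outputs (computed on a partially zero-filled state) need to be discarded.

```python
class ToySnapshotter:
    """Hypothetical single-channel stand-in for a server-side snapshotter."""

    def __init__(self, kernel_size):
        self.state = [0.0] * kernel_size  # state starts out as all zeros

    def update(self, new_samples):
        # drop the oldest samples and append the new ones at the end
        self.state = self.state[len(new_samples):] + list(new_samples)
        return list(self.state)


kernel_size, stride = 8, 2
data = [float(i) for i in range(1, 21)]  # 20 nonzero samples

snapshotter = ToySnapshotter(kernel_size)
snapshots = [
    snapshotter.update(data[i: i + stride])
    for i in range(0, len(data), stride)
]

# windows built directly by slicing, as in the non-streaming case
direct = [
    data[i: i + kernel_size]
    for i in range(0, len(data) - kernel_size + 1, stride)
]

# the first (kernel_size // stride - 1) snapshots still contain zeros
warmup = kernel_size // stride - 1
print(snapshots[0])
print(snapshots[warmup:] == direct)
```

Each request now carries only `stride` samples instead of `kernel_size`, while the windows seen by the downstream model are identical once the state has filled up.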
Now we can run inference on our streaming model and only send the updates we need for each request, rather than the entire window!
# keep our inference results from earlier to
# verify that this implementation comes out the same
nonstreaming_results = results
# quick addition to the callback to handle waiting
# for the first couple responses to come back
class BlockingCallback(Callback):
    def block(self, i):
        while self.y[i] == 0:
            time.sleep(1e-3)

# we need to do more inferences this time, since
# the snapshot state gets initialized to 0s, so the
# first KERNEL_LENGTH * INFERENCE_SAMPLING_RATE updates
# just function to fill the snapshot out
num_inferences = inference_data_size // inference_stride
callback = BlockingCallback(num_inferences)
metrics_file = metrics_dir / "streaming_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )
    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + inference_stride  # note the smaller slice
            kernel = hoft[:, start: stop]

            # provide some additional information to
            # the inference server to allow us to keep
            # track of multiple different streams
            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )

            # block on the first few responses so that a
            # queue doesn't build up behind the slow first request
            if i < 3:
                callback.block(i)
                if i == 2:
                    logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break

# ditch some of the initial inferences, which took
# place on a 0-initialized kernel with some updates
# placed at the end of it
results = results[int(KERNEL_LENGTH * INFERENCE_SAMPLING_RATE) - 1:]

# validate that this implementation produces the
# same results as the original
assert (results == nonstreaming_results).all()
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest expensive_milkshake_3248
17:01:50.692 Waiting for inference service to come online...
17:02:30.022 Service ready!
17:02:37.252 First 3 requests completed
17:02:45.639 Inference complete!
p = plotting.plot_timeseries(results, INFERENCE_SAMPLING_RATE, KERNEL_LENGTH)
plotting.show(p)
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    dfs[model] = subdf
p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)
So this looks like a pretty stable configuration, which means we can up our inference rate until we start to hit a bottleneck. Let’s hop up to 400 and see what happens.
INFERENCE_RATE = 400
kernels_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE)
rate_limiter = RateLimiter(max_calls=kernels_per_second / 20, period=0.05)
callback = BlockingCallback(num_inferences)
metrics_file = metrics_dir / "streaming_rate-400_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")

    client = InferenceClient(
        "localhost:8001",
        model_name="streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor",
        rate=4,
    )
    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride
            stop = start + inference_stride
            kernel = hoft[:, start: stop]
            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )
            if i < 3:
                callback.block(i)
                if i == 2:
                    logger.info("First 3 requests completed")

        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest arid_frito_0916
17:02:48.609 Waiting for inference service to come online...
17:03:27.587 Service ready!
17:03:34.698 First 3 requests completed
17:05:05.904 Inference complete!
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    dfs[model] = subdf
p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)
So it looks like our queue latencies are still reasonably stable, but our throughput ends up tanking early in our inference run. The likely culprit is that we’ve saturated the rate at which our client can generate requests: gRPC’s asynchronous requests are implemented with Python threads in a way that doesn’t lend itself to doing strenuous work in the main thread which generates the requests. We can improve throughput further by sending longer updates at a lower rate, then using the caching model on the server to build batches of kernels to send to the downstream model.
Let’s export a new ensemble and associated snapshotter model that expects batched updates, then run inference with that.
batched_ensemble = repo.add("batched-streaming-classifier", platform=qv.Platform.ENSEMBLE)
# right now we can only handle batches small
# enough that the update isn't longer than
# the kernel itself, so we'll use the biggest
# batch we can with that constraint
batch_size = int(KERNEL_LENGTH * INFERENCE_SAMPLING_RATE)
classifier = repo.models["my-classifier"]
batched_ensemble.add_streaming_inputs(
    classifier.inputs["hoft"],
    stream_size=inference_stride,
    batch_size=batch_size,
    name="batched-snapshotter"
)
# we know our first request takes ~12s, so make
# sure that the snapshotter will maintain a state
# for longer than this
snapshotter = repo.models["batched-snapshotter"]
snapshotter.config.sequence_batching.max_sequence_idle_microseconds = int(25 * 10**6)
snapshotter.config.write()
# mark the output of our classifier as the output
# of the whole ensemble and then export a "version"
# of the ensemble (basically just writes its config)
batched_ensemble.add_output(classifier.outputs["prob"])
batched_ensemble.export_version(None)
'batched-streaming-classifier/1/model.empty'
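Since each response now carries `batch_size` outputs rather than one, the callback in the next cell has to write every response into the right slice of the output timeseries. The indexing can be sketched on its own as follows; the helper `place` and the sizes used here are hypothetical and stand in for the callback’s `__call__` method.

```python
import numpy as np

batch_size = 4      # hypothetical: network outputs per response
num_inferences = 3  # hypothetical: number of requests sent

# flat timeseries holding one output per kernel
y = np.zeros(num_inferences * batch_size)


def place(response, request_id):
    """Write a (batch_size, 1) response into its slot in `y`."""
    start = request_id * batch_size
    y[start: start + len(response)] = response[:, 0]


# the slot depends only on request_id, so the order in which
# responses are processed does not affect the final timeseries
for request_id in [1, 0, 2]:
    response = np.full((batch_size, 1), float(request_id))
    place(response, request_id)
```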
# make a new callback that's better equipped to
# slice out a batch of responses
class BatchedCallback(BlockingCallback):
    def __init__(self, num_inferences, batch_size):
        self.y = np.zeros((num_inferences * batch_size,))
        self.batch_size = batch_size

    def __call__(self, response, request_id, sequence_id):
        start = request_id * self.batch_size
        self.y[start: start + len(response)] = response[:, 0]
        if (start + len(response) + 1) >= len(self.y):
            return self.y
num_kernels = inference_data_size // inference_stride
num_inferences = num_kernels // batch_size
callback = BatchedCallback(num_inferences, batch_size)
batches_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE / batch_size)
rate_limiter = RateLimiter(max_calls=batches_per_second / 20, period=0.05)
metrics_file = metrics_dir / "batched-streaming_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")
    client = InferenceClient(
        "localhost:8001",
        model_name="batched-streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="batched-streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )
    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride * batch_size
            stop = (i + 1) * inference_stride * batch_size
            kernel = hoft[:, start: stop]
            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )
            if i < 3:
                callback.block(i)
                if i == 2:
                    logger.info("First 3 requests completed")
        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest wobbly_hippo_8793
17:05:09.062 Waiting for inference service to come online...
17:05:52.971 Service ready!
17:05:59.821 First 3 requests completed
17:06:05.004 Inference complete!
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    subdf["Throughput (s' / s)"] *= batch_size
    dfs[model] = subdf
p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)
OK, so we’re back to a stable system. Let’s push this a bit further, using a longer segment of data so that we get a decent estimate of our metrics.
INFERENCE_DATA_LENGTH = 32768
INFERENCE_RATE = 4096
inference_data_size = INFERENCE_DATA_LENGTH * SAMPLE_RATE
hoft = np.random.randn(NUM_IFOS, inference_data_size).astype("float32")
num_kernels = inference_data_size // inference_stride
num_inferences = num_kernels // batch_size
callback = BatchedCallback(num_inferences, batch_size)
batches_per_second = int(INFERENCE_RATE * INFERENCE_SAMPLING_RATE / batch_size)
rate_limiter = RateLimiter(max_calls=batches_per_second / 20, period=0.05)
metrics_file = metrics_dir / "batched-streaming_rate-4096_single-model.csv"
with serve("model-repo", gpus=[0]) as instance:
    logger.info("Waiting for inference service to come online...")
    instance.wait()
    logger.info("Service ready!")
    client = InferenceClient(
        "localhost:8001",
        model_name="batched-streaming-classifier",
        model_version=1,
        callback=callback
    )
    monitor = ServerMonitor(
        model_name="batched-streaming-classifier",
        ips="localhost",
        filename=metrics_file,
        model_version=1,
        name="monitor"
    )
    with client, monitor:
        for i in range(int(num_inferences)):
            start = i * inference_stride * batch_size
            stop = (i + 1) * inference_stride * batch_size
            kernel = hoft[:, start: stop]
            with rate_limiter:
                client.infer(
                    kernel,
                    request_id=i,
                    sequence_id=1001,
                    sequence_start=i == 0,
                    sequence_end=(i + 1) == num_inferences
                )
            if i < 3:
                callback.block(i)
                if i == 2:
                    logger.info("First 3 requests completed")
        while True:
            results = client.get()
            if results is not None:
                logger.info("Inference complete!")
                break
singularity -s instance start --nv /cvmfs/singularity.opensciencegrid.org/fastml/gwiaas.tritonserver:latest joyous_signal_7729
17:06:14.401 Waiting for inference service to come online...
17:06:57.836 Service ready!
17:07:04.844 First 3 requests completed
17:07:59.569 Inference complete!
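As a rough cross-check on the log above, we can compare the wall-clock time of this run against the time it would take at the requested rate. This sketch assumes that `INFERENCE_RATE` is a target throughput in s’ / s, which is what the rate limiter arithmetic suggests, and it copies the two timestamps from the log by hand. The elapsed time starts after the first three requests and includes client-side overhead, so treat the result as a crude end-to-end estimate and not a substitute for the server-side metrics plotted next.

```python
from datetime import datetime

INFERENCE_DATA_LENGTH = 32768  # seconds of data, as set for this run
INFERENCE_RATE = 4096          # assumed to be the target rate in s' / s

# timestamps copied by hand from the log lines above
fmt = "%H:%M:%S.%f"
first_requests_done = datetime.strptime("17:07:04.844", fmt)
inference_complete = datetime.strptime("17:07:59.569", fmt)
elapsed = (inference_complete - first_requests_done).total_seconds()

# time the run would take if the target rate were sustained throughout
ideal = INFERENCE_DATA_LENGTH / INFERENCE_RATE

# end-to-end throughput actually achieved over the elapsed time
observed = INFERENCE_DATA_LENGTH / elapsed

print(f"ideal run time at target rate: {ideal:.1f} s")
print(f"elapsed: {elapsed:.1f} s, roughly {observed:.0f} s' / s end to end")
```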
df = pd.read_csv(metrics_file)
dfs = {}
for model, subdf in df.groupby("model"):
    subdf = cleanup_df(subdf.reset_index(), df.timestamp.min())
    subdf["Throughput (s' / s)"] *= batch_size
    dfs[model] = subdf
p = plotting.plot_inference_metrics_vs_time(**dfs)
plotting.show(p)