A typical model-serving scenario: we need to deploy a Python-based model as an online inference service that fetches feature data in real time for each request and logs the result after inference. As the business grows, read QPS against the feature store is expected to reach the tens of thousands, while writes remain comparatively infrequent, so a read-write-splitting database architecture is all but inevitable. At the same time, to keep the data layer robust, type-safe, and capable of handling high concurrency, we decided not to couple the data-access logic with the model service in the same Python process.
This leads to two core architectural decision points.
Option A: A Pure-Python Stack
Inside the BentoML service, use a Python database library (such as SQLAlchemy) to connect directly to the read-write-split database cluster. One service, one codebase, one set of dependencies.
- Pros: Simple architecture, fast development, a unified stack. For small to mid-sized projects, this is the most direct path.
- Cons: In real projects, the weaknesses of this approach surface quickly. First, Python's GIL is a bottleneck for highly concurrent I/O-bound workloads, and the ecosystem lacks a connection-pool solution with the maturity and fine-grained control that HikariCP enjoys in the JVM world. Second, the data-access logic is tightly coupled to the model-inference logic: any data-model change or database optimization forces a redeployment of the entire model service, which is disastrous for cross-team collaboration. The iteration cycles of the data team and the algorithm team become forcibly bound together.
Option B: A Heterogeneous Microservice Architecture
Split out the data-access logic into a dedicated, high-performance Data Access Layer (DAL). Given the requirements for high concurrency, strong typing, and a mature ecosystem, we choose Scala for this service. The BentoML inference service talks to the Scala DAL over RPC.
- Pros:
  - Specialization: fully leverage the JVM ecosystem's strengths in high-concurrency I/O and data processing. Scala's strong type system and functional features make for extremely robust, maintainable data services.
  - Separation of concerns: algorithm engineers focus on the model logic in `service.py`, while data engineers maintain the Scala DAL; the two can be iterated, deployed, and scaled independently.
  - Performance and stability: the Scala service can handle massive numbers of concurrent connections with frameworks like Akka/ZIO, and top-tier connection pools such as HikariCP keep database interaction stable and efficient.
- Cons:
  - Architectural complexity: a new service brings network latency, service discovery, and serialization overhead.
  - New core challenges:
    - Cross-language observability: a request starts in the Python service, crosses the network into the Scala service, and then hits the database. How do we build an end-to-end distributed trace to pinpoint performance bottlenecks and errors?
    - Managing the read-write split: the DAL must route each request intelligently to the primary or a replica. Trickier still, how do we handle the consistency problems caused by replication lag?
After weighing the options, for a production-grade system that must evolve over the long term, Option B's advantages far outweigh its complexity. The focus of this article is how to implement Option B, with particular attention to the two core challenges it introduces. We will use gRPC as the communication protocol: it runs on HTTP/2, performs well, and provides a strongly typed cross-language contract through Protocol Buffers. BentoML's API server is asynchronous, which pairs naturally with `grpc.aio`.
Architecture and Call-Chain Design
Before diving into code, let's pin down the system structure and request flow with a diagram.
sequenceDiagram
    participant Client
    participant BentoML_Service as BentoML (Python)
    participant Scala_DAL as Scala DAL (Akka/gRPC)
    participant DB_Primary as Primary DB
    participant DB_Replica as Replica DB
    Client->>+BentoML_Service: HTTP POST /predict
    Note right of BentoML_Service: 1. OTel Tracer starts a new span
    BentoML_Service->>+Scala_DAL: gRPC GetFeatures(user_id)
    Note right of BentoML_Service: 2. Trace context injected into gRPC metadata
    Note left of Scala_DAL: 3. OTel Interceptor extracts trace context,<br/>starts a new child span
    Scala_DAL->>+DB_Replica: SQL SELECT ... FROM features WHERE ...
    DB_Replica-->>-Scala_DAL: Feature Data
    Scala_DAL-->>-BentoML_Service: FeaturesResponse
    BentoML_Service->>BentoML_Service: model.predict(features)
    Note right of BentoML_Service: 4. Local computation
    BentoML_Service->>+Scala_DAL: gRPC LogPrediction(prediction_result)
    Scala_DAL->>+DB_Primary: SQL INSERT INTO predictions ...
    DB_Primary-->>-Scala_DAL: Ack
    Scala_DAL-->>-BentoML_Service: LogResponse
    BentoML_Service-->>-Client: HTTP 200 OK
This flow clearly shows the two RPC calls: one read (`GetFeatures`) and one write (`LogPrediction`). Our core implementation revolves around this flow.
Core Implementation: Building the Cross-Language Bridge
1. Defining the Service Contract (Protocol Buffers)
Everything starts with the interface definition. The proto file is the contract between the Python service and the Scala service.
feature_service.proto:
syntax = "proto3";
package com.mycorp.dal;
// The feature service definition.
service FeatureService {
// Retrieves features for a given entity
rpc GetFeatures (FeatureRequest) returns (FeatureResponse);
// Logs a prediction result
rpc LogPrediction (LogRequest) returns (LogResponse);
}
message FeatureRequest {
string entity_id = 1;
}
message FeatureResponse {
map<string, float> features = 1;
bool found = 2;
}
message LogRequest {
string entity_id = 1;
string model_version = 2;
map<string, float> prediction_scores = 3;
}
message LogResponse {
bool success = 1;
string message = 2;
}
The contract is explicit: `GetFeatures` is for reads, `LogPrediction` is for writes. This is the basis on which the Scala DAL will route between primary and replica.
2. Implementing the Scala Data Access Layer (DAL)
We will use Akka gRPC, Slick for database interaction, and the OpenTelemetry Java Agent for non-invasive observability.
Project layout (sbt):
- build.sbt
- project/plugins.sbt
- src/main/
  - proto/feature_service.proto
  - scala/com/mycorp/dal/
    - DatabaseConfig.scala
    - FeatureServiceImpl.scala
    - DataAccessServer.scala
  - resources/
    - application.conf
    - logback.xml
Key dependencies in build.sbt:
// build.sbt
val AkkaVersion = "2.6.20"
val AkkaGrpcVersion = "2.3.2"

libraryDependencies ++= Seq(
  "com.typesafe.akka" %% "akka-actor-typed" % AkkaVersion,
  "com.typesafe.akka" %% "akka-stream" % AkkaVersion,
  "com.typesafe.akka" %% "akka-discovery" % AkkaVersion,
  "com.typesafe.akka" %% "akka-pki" % AkkaVersion,
  "com.lightbend.akka.grpc" %% "akka-grpc-runtime" % AkkaGrpcVersion,
  "com.typesafe.slick" %% "slick" % "3.4.1",
  "com.typesafe.slick" %% "slick-hikaricp" % "3.4.1",
  "org.postgresql" % "postgresql" % "42.5.4", // Or your DB driver
  "ch.qos.logback" % "logback-classic" % "1.4.5"
)

// Enable Akka gRPC code generation
enablePlugins(AkkaGrpcPlugin)
Database configuration and read-write routing (DatabaseConfig.scala):
This is the heart of the DAL. Instead of a single database connection, we define separate configurations for the primary and a replica, and a simple router to dispatch database operations.
// src/main/scala/com/mycorp/dal/DatabaseConfig.scala
package com.mycorp.dal

import slick.jdbc.JdbcBackend.Database
import com.typesafe.config.ConfigFactory

object DatabaseConfig {
  private val config = ConfigFactory.load()

  // Primary database configuration for writes. Note: Database.forConfig
  // expects the path of the connection settings themselves, hence ".db".
  val primaryDb: Database = Database.forConfig("database.primary.db", config)

  // Replica database configuration for reads
  val replicaDb: Database = Database.forConfig("database.replica.db", config)

  // A simple enum to represent the database role
  sealed trait DbRole
  case object Primary extends DbRole
  case object Replica extends DbRole

  /**
   * Selects the database based on the role.
   * This is the core of our read-write splitting logic.
   * In a real-world scenario, this could be more complex, e.g., round-robin
   * over multiple replicas.
   * @param role The desired database role (Primary or Replica)
   * @return The corresponding Slick Database instance
   */
  def getDb(role: DbRole): Database = role match {
    case Primary => primaryDb
    case Replica => replicaDb
  }

  def closeDbs(): Unit = {
    primaryDb.close()
    replicaDb.close()
  }
}
The application.conf configuration:
# src/main/resources/application.conf
database {
  primary {
    profile = "slick.jdbc.PostgresProfile$"
    db {
      dataSourceClass = "org.postgresql.ds.PGSimpleDataSource"
      properties = {
        serverName = "primary-db-host"
        portNumber = "5432"
        databaseName = "mydatabase"
        user = "user"
        password = "password"
      }
      numThreads = 10 // Connection pool size
      // HikariCP specific settings
      connectionTimeout = 30000
      validationTimeout = 5000
    }
  }
  replica {
    profile = "slick.jdbc.PostgresProfile$"
    db {
      dataSourceClass = "org.postgresql.ds.PGSimpleDataSource"
      properties = {
        serverName = "replica-db-host"
        portNumber = "5432"
        databaseName = "mydatabase"
        user = "user"
        password = "password"
        // Key setting for replicas to ensure we don't accidentally write
        readOnly = true
      }
      numThreads = 20 // Can have a larger pool for read replicas
      connectionTimeout = 30000
    }
  }
}

// Akka gRPC server config
akka.grpc.server {
  service-port-name = "feature-service"
}
gRPC service implementation (FeatureServiceImpl.scala):
Here we wire the interface defined in the proto file to the database logic. Note that `getFeatures` uses `Replica` while `logPrediction` uses `Primary`.
// src/main/scala/com/mycorp/dal/FeatureServiceImpl.scala
package com.mycorp.dal

import scala.concurrent.{ExecutionContext, Future}
import akka.actor.typed.ActorSystem
import org.slf4j.LoggerFactory
import slick.jdbc.PostgresProfile.api._

class FeatureServiceImpl(implicit ec: ExecutionContext, system: ActorSystem[_]) extends FeatureService {
  private val log = LoggerFactory.getLogger(classOf[FeatureServiceImpl])

  // Define table mapping (Slick's way). entity_id + feature_name form a
  // composite primary key, declared explicitly below.
  private class Features(tag: Tag) extends Table[(String, String, Float)](tag, "features") {
    def entityId = column[String]("entity_id")
    def featureName = column[String]("feature_name")
    def featureValue = column[Float]("feature_value")
    def pk = primaryKey("features_pk", (entityId, featureName))
    def * = (entityId, featureName, featureValue)
  }
  private val features = TableQuery[Features]

  private class Predictions(tag: Tag) extends Table[(String, String, String, Float)](tag, "predictions") {
    def entityId = column[String]("entity_id")
    def modelVersion = column[String]("model_version")
    def scoreName = column[String]("score_name")
    def scoreValue = column[Float]("score_value")
    def * = (entityId, modelVersion, scoreName, scoreValue)
  }
  private val predictions = TableQuery[Predictions]

  override def getFeatures(in: FeatureRequest): Future[FeatureResponse] = {
    log.info(s"Fetching features for entity: ${in.entityId}")
    val db = DatabaseConfig.getDb(DatabaseConfig.Replica) // USE REPLICA
    val query = features.filter(_.entityId === in.entityId).result
    db.run(query).map { featureRows =>
      if (featureRows.isEmpty) {
        FeatureResponse(found = false)
      } else {
        val featureMap = featureRows.map { case (_, name, value) => name -> value }.toMap
        FeatureResponse(features = featureMap, found = true)
      }
    }.recover {
      case ex: Exception =>
        log.error(s"Failed to get features for ${in.entityId}", ex)
        // In a real system, you might return a specific gRPC error code
        throw new RuntimeException("Database query failed")
    }
  }

  override def logPrediction(in: LogRequest): Future[LogResponse] = {
    log.info(s"Logging prediction for entity: ${in.entityId}")
    val db = DatabaseConfig.getDb(DatabaseConfig.Primary) // USE PRIMARY
    val rowsToInsert = in.predictionScores.map { case (name, value) =>
      (in.entityId, in.modelVersion, name, value)
    }.toSeq
    val action = predictions ++= rowsToInsert
    db.run(action.transactionally).map {
      case Some(_) | None => // Slick returns Option[Int] for batch inserts
        log.info(s"Successfully logged ${rowsToInsert.size} scores for ${in.entityId}")
        LogResponse(success = true)
    }.recover {
      case ex: Exception =>
        log.error(s"Failed to log prediction for ${in.entityId}", ex)
        LogResponse(success = false, message = ex.getMessage)
    }
  }
}
Starting the service with integrated observability:
We use the OpenTelemetry Java Agent, which performs most of the instrumentation automatically via bytecode injection, including Akka HTTP/gRPC and JDBC.
The launch script run.sh:
#!/bin/bash
OTEL_AGENT_PATH="./opentelemetry-javaagent.jar"
OTEL_SERVICE_NAME="scala-dal-service"
OTEL_EXPORTER_OTLP_ENDPOINT="http://jaeger-collector:4317" # Or your OTel collector endpoint
java -javaagent:${OTEL_AGENT_PATH} \
  -Dotel.service.name=${OTEL_SERVICE_NAME} \
  -Dotel.exporter.otlp.endpoint=${OTEL_EXPORTER_OTLP_ENDPOINT} \
  -cp /path/to/your/fat.jar com.mycorp.dal.DataAccessServer
That is almost all the work needed for observability on the Scala side. The agent detects whether the incoming gRPC metadata carries a W3C Trace Context; if it does, it creates a child span, linking the trace across the service boundary.
3. Implementing the BentoML Inference Service
Now to the Python side. We need a BentoML service that acts as a gRPC client to the Scala DAL.
Project layout:
- bentofile.yaml
- service.py
- requirements.txt
- generated_grpc/ # Python code generated from .proto file
First, generate the Python gRPC code. One caveat: `grpc_tools.protoc` emits an absolute `import feature_service_pb2 ...` inside the generated `feature_service_pb2_grpc.py`, so to import it as the `generated_grpc` package used below, you will also need an `__init__.py` there and a small import adjustment.
python -m grpc_tools.protoc \
-I. \
--python_out=./generated_grpc \
--grpc_python_out=./generated_grpc \
feature_service.proto
requirements.txt:
bentoml
grpcio
grpcio-tools
opentelemetry-api
opentelemetry-sdk
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-instrumentation-grpc
# Plus your model's dependencies, e.g., scikit-learn
service.py:
# service.py
import bentoml
import grpc

from generated_grpc import feature_service_pb2, feature_service_pb2_grpc
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient

# --- 1. Observability Setup ---
# This setup should be done once when the process starts.
def setup_observability():
    resource = Resource(attributes={"service.name": "bentoml-inference-service"})
    provider = TracerProvider(resource=resource)
    # Configure the OTLP exporter to send traces to a collector
    otlp_exporter = OTLPSpanExporter(endpoint="http://jaeger-collector:4317", insecure=True)
    processor = BatchSpanProcessor(otlp_exporter)
    provider.add_span_processor(processor)
    # Sets the global default tracer provider
    trace.set_tracer_provider(provider)
    # Auto-instrumentation for the async gRPC client. Incoming requests are
    # traced by BentoML's built-in OpenTelemetry support.
    GrpcAioInstrumentorClient().instrument()

setup_observability()
tracer = trace.get_tracer(__name__)

# --- 2. gRPC Client Setup ---
# A single, reusable gRPC channel is crucial for performance.
# We use an async channel because BentoML's API server is async.
DAL_SERVICE_ADDRESS = "scala-dal-service:8080"

# This channel should be created once and reused across requests. For
# simplicity we define it globally here; a more robust solution would
# manage its lifecycle in a startup hook.
_grpc_channel = grpc.aio.insecure_channel(DAL_SERVICE_ADDRESS)
_grpc_stub = feature_service_pb2_grpc.FeatureServiceStub(_grpc_channel)

# --- 3. A Dummy Model for Demonstration ---
# In a real scenario, you'd load a trained model.
class MySimpleModel:
    def predict(self, features: dict) -> float:
        # A simple linear combination
        score = 0.5 * features.get('feature_a', 0) + 0.3 * features.get('feature_b', 0)
        return score

model = MySimpleModel()

# --- 4. BentoML Service Definition ---
@bentoml.service
class InferenceService:
    @bentoml.api
    async def predict(self, data: dict) -> dict:
        entity_id = data.get("entity_id")
        if not entity_id:
            return {"error": "entity_id is required"}

        # BentoML's built-in tracing has already started a span here.

        # --- RPC Call 1: Get Features ---
        try:
            feature_request = feature_service_pb2.FeatureRequest(entity_id=entity_id)
            # The gRPC instrumentor injects the trace context into the request metadata here.
            feature_response = await _grpc_stub.GetFeatures(feature_request)
            if not feature_response.found:
                return {"error": f"Features not found for entity {entity_id}"}
            features = dict(feature_response.features)
        except grpc.aio.AioRpcError as e:
            # Proper error handling is critical in a production system.
            # Here we just log and return an error.
            print(f"gRPC call to GetFeatures failed: {e}")
            return {"error": "Failed to communicate with data service"}

        # --- Model Inference ---
        # You can create a manual span for more granularity if needed.
        with tracer.start_as_current_span("model_prediction") as span:
            prediction_score = model.predict(features)
            span.set_attribute("entity_id", entity_id)
            span.set_attribute("prediction_score", prediction_score)

        # --- RPC Call 2: Log Prediction ---
        try:
            log_request = feature_service_pb2.LogRequest(
                entity_id=entity_id,
                model_version="v1.2.3",
                prediction_scores={"main_score": prediction_score}
            )
            # The context is propagated again for this second call.
            await _grpc_stub.LogPrediction(log_request)
        except grpc.aio.AioRpcError as e:
            # In a production system, logging failure might be non-critical.
            # We just log the error and continue, instead of failing the request.
            print(f"gRPC call to LogPrediction failed: {e}")
            # Do not return an error to the client.

        return {"entity_id": entity_id, "score": prediction_score}
bentofile.yaml:
service: "service:InferenceService"
labels:
owner: ml-team
project: real-time-prediction
include:
- "*.py"
- "generated_grpc/"
python:
requirements_txt: "./requirements.txt"
When this BentoML service runs, `opentelemetry-instrumentation-grpc` automatically intercepts the calls to `_grpc_stub.GetFeatures` and `_grpc_stub.LogPrediction`, extracts the Trace Context from the current span (created by BentoML's built-in request tracing), and injects it into the gRPC request metadata. This is what keeps the trace alive across the language boundary.
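For illustration, the following is roughly the manual equivalent of what the instrumentor does for every call: a minimal sketch (the helper name is hypothetical) that uses the W3C propagator to serialize the current span context into outgoing metadata.
from opentelemetry.propagate import inject

from generated_grpc import feature_service_pb2

# Manual equivalent of the instrumentor's work: serialize the current trace
# context (a W3C "traceparent" entry) into the outgoing gRPC metadata.
async def get_features_with_manual_context(stub, entity_id: str):
    carrier: dict = {}
    inject(carrier)  # e.g. {"traceparent": "00-<trace-id>-<span-id>-01"}
    return await stub.GetFeatures(
        feature_service_pb2.FeatureRequest(entity_id=entity_id),
        metadata=tuple(carrier.items()),
    )
In practice you should not need this: the instrumentor patches the channel so every call carries the context automatically.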
Limitations and Future Iteration Paths
This architecture solves separation of concerns and cross-language observability, but it is no silver bullet. In production, new challenges appear quickly.
First, the consistency problem introduced by the read-write split is encapsulated by the DAL, not solved. Imagine a new user registering: their features are written to the primary. If an inference request arrives milliseconds later and the `GetFeatures` call is routed to a replica that has not yet replicated that user's data, inference fails. The current architecture has no answer for this. Possible remedies include (the third is sketched after this list):
- Session stickiness: at an upstream gateway or in the BentoML service layer, route all requests from the same user to the primary within a short window (say, 5 seconds), at the cost of extra read load on the primary.
- Caching: introduce Redis or a similar cache in the Scala DAL. Writes go to the primary and update the cache; reads hit the cache first and fall back to the replica on a miss. This adds architectural complexity.
- Application-level compensation: when a feature lookup fails, the application waits briefly and retries, betting that replication has caught up in the meantime.
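A minimal sketch of that third option on the Python side, assuming a missing entity is most likely explained by replication lag (`stub` is the `FeatureServiceStub` from `service.py`; the helper name is illustrative):
import asyncio

from generated_grpc import feature_service_pb2

# If the lookup comes back empty, wait briefly and retry, assuming the
# primary-to-replica sync may simply not have caught up yet.
async def get_features_with_lag_tolerance(stub, entity_id: str,
                                          retries: int = 2,
                                          wait_seconds: float = 0.2):
    for attempt in range(retries + 1):
        response = await stub.GetFeatures(
            feature_service_pb2.FeatureRequest(entity_id=entity_id))
        if response.found:
            return dict(response.features)
        if attempt < retries:
            await asyncio.sleep(wait_seconds)  # give replication a moment
    return None  # genuinely absent, or lag outlasted our patience
The waits bound worst-case latency at roughly retries × wait_seconds, so keep them small relative to your latency budget.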
Second, fault tolerance for the gRPC calls. Networks are unreliable. A failed `GetFeatures` call should fail the whole request, but a failed `LogPrediction` may only need to be logged, without affecting the success response to the client. This calls for finer-grained retry and degradation logic in the Python client: for example, use the `tenacity` library to retry `GetFeatures` with exponential backoff, while treating `LogPrediction` as "fire-and-forget" or pushing it onto a message queue for asynchronous processing.
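A sketch of both policies, assuming the `tenacity` library is installed (the helper names are illustrative):
import asyncio

import grpc
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

# Critical read path: retry transient gRPC failures with exponential backoff.
@retry(
    retry=retry_if_exception_type(grpc.aio.AioRpcError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.1, max=1.0),
)
async def get_features_reliably(stub, request):
    return await stub.GetFeatures(request)

# Non-critical write path: schedule the call in the background and swallow
# failures after logging, so the client response is never blocked on it.
def log_prediction_fire_and_forget(stub, request):
    async def _log():
        try:
            await stub.LogPrediction(request)
        except grpc.aio.AioRpcError as e:
            print(f"Non-critical: LogPrediction failed: {e}")
    asyncio.get_running_loop().create_task(_log())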
Finally, service discovery and load balancing. We currently hard-code the DAL address. In production, the Scala DAL runs as multiple instances, and the BentoML service must locate healthy instances through a service-discovery mechanism (Consul, a Kubernetes Service, etc.) and balance load on the client side. `grpc-python` integrates with several load-balancing policies, but this requires extra configuration and infrastructure support.
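As a sketch, gRPC's built-in DNS resolver combined with the `round_robin` policy can already spread calls across instances, assuming a headless Kubernetes Service (the hostname below is hypothetical) that resolves to the individual pod IPs:
import json

import grpc

# Ask the channel to resolve all A records and round-robin across them.
service_config = json.dumps({"loadBalancingConfig": [{"round_robin": {}}]})
channel = grpc.aio.insecure_channel(
    "dns:///scala-dal-service.default.svc.cluster.local:8080",
    options=[
        ("grpc.service_config", service_config),
        # Re-resolve DNS periodically so newly scaled pods are picked up.
        ("grpc.dns_min_time_between_resolutions_ms", 5000),
    ],
)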