Translator | Bugatti
Reviewer | Chonglou
Redis is an open-source, in-memory data structure store and a top choice for caching in machine learning applications. Its speed, durability, and support for a variety of data structures make it ideal for meeting the high-throughput demands of real-time inference tasks.
In this tutorial we will explore why Redis caching matters in machine learning workflows and demonstrate how to build a robust machine learning application with FastAPI and Redis. The tutorial covers installing Redis on Windows, running it locally, and integrating it into a machine learning project. Finally, we will test the application by sending both duplicate and unique requests to verify that the Redis caching system is working correctly.
Why Use Redis Caching in Machine Learning?
In today's fast-paced digital landscape, users expect machine learning applications to deliver results instantly. Consider an e-commerce platform that uses a recommendation model to suggest products to its users. By using Redis to cache repeated requests, the platform can significantly reduce its response times.
When a user requests product recommendations, the system first checks whether that request has already been cached. If it has, the cached response is returned within milliseconds, providing a seamless experience. If not, the model processes the request, generates the recommendations, and stores the result in Redis for future requests. This approach not only improves user satisfaction but also conserves server resources, allowing the model to handle more requests efficiently.
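The flow just described is the classic cache-aside pattern. Below is a minimal sketch of it in Python; the recommend function, the user_id key scheme, and the model object are illustrative assumptions rather than part of the project built later in this tutorial.
import json

import redis

# Assumes a Redis server on localhost and an already trained `model` object.
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

def recommend(model, user_id: str) -> dict:
    key = f"recommendation:{user_id}"
    cached = redis_client.get(key)                     # 1. check the cache first
    if cached:
        return json.loads(cached)                      # 2. hit: return the stored result
    items = [str(item) for item in model.predict([user_id])]  # 3. miss: run the model
    result = {"items": items}
    redis_client.setex(key, 3600, json.dumps(result))  # 4. cache the result for one hour
    return result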
Building a Phishing Email Classification Application with Redis
In this project we will build a phishing email classification application. The workflow covers loading and processing a dataset from Kaggle, training a machine learning model on the processed data, evaluating its performance, saving the trained model, and finally building a FastAPI application with Redis integration.
1. Setup
- Download the phishing email detection dataset from Kaggle and place it in the data/ directory.
- First, install Redis. Run the following command in a terminal to install the Redis Python client:
pip install redis
- If you are on Windows and do not have the Windows Subsystem for Linux (WSL) installed, follow Microsoft's guide to enable WSL and install a Linux distribution (such as Ubuntu) from the Microsoft Store.
- Once WSL is set up, open a WSL terminal and run the following commands to install Redis:
sudo apt update
sudo apt install redis-server
- To start the Redis server, run:
sudo service redis-server start
You should see a confirmation message indicating that redis-server has started successfully.
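Before moving on, you can confirm that the Python client installed earlier can reach the server. This quick check assumes Redis is listening on the default localhost:6379:
import redis

client = redis.Redis(host="localhost", port=6379, db=0)
print(client.ping())  # prints True when the Redis server is reachable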
2. Model Training
The training script loads the dataset, processes the data, trains the model, and saves it locally.
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )

    # Train the model
    pipeline.fit(X_train, y_train)

    # Save the trained model to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")


if __name__ == "__main__":
    main()
python train.py
Model trained and saved as phishing_model.pkl
3. Model Evaluation
The evaluation script loads the dataset and the saved model file to evaluate the model.
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib
def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset the same way as during training
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Load the trained model
    model = joblib.load("phishing_model.pkl")

    # Make predictions on the test set
    y_pred = model.predict(X_test)

    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))


if __name__ == "__main__":
    main()
The results are close to perfect, and the F1-scores are excellent.
python validate.py
Accuracy: 0.9723860589812332
Classification Report:
                precision    recall  f1-score   support

Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273

      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730
4. Serving the Model with Redis
To serve the model, we will create a REST API with FastAPI and integrate Redis to cache predictions.
import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis
# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")
app = FastAPI()
# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    prediction: str
    probability: float


@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as a cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    # Run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())
    result = {"prediction": str(pred[0]), "probability": float(prob)}

    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
python serve.py
INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
You can view the interactive REST API documentation by visiting http://localhost:8000/docs in your browser.
The source code, configuration files, model, and dataset for this project are available in the kingabzpro/Redis-ml-project GitHub repository. Feel free to refer to it if you run into any issues with the code above.
How Redis Caching Works in the Machine Learning Application
Here is a step-by-step explanation of how Redis caching operates in our machine learning application, accompanied by a flowchart:
- The client submits input data and requests a prediction from the machine learning model.
- The system generates a unique identifier from the input data to check whether a prediction already exists (a minimal key-generation sketch follows this list).
- The system queries the Redis cache with the generated key to look up a previously stored prediction.
A. If a cached prediction is found, it is retrieved and returned as a JSON response.
B. If no cached prediction is found, the input data is passed to the machine learning model to generate a new prediction.
- The newly generated prediction is stored in the Redis cache for future use.
- The final result is returned to the client in JSON format.
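In serve.py the unique identifier is simply the raw email text prefixed with "prediction:". A common alternative, shown in this hedged sketch (not used by the project code), is to hash the input so the key stays short and uniform even for very long emails:
import hashlib

def make_cache_key(text: str) -> str:
    # Hash the email text so the Redis key length is fixed regardless of input size.
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return f"prediction:{digest}"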
Testing the Phishing Email Classification Application
Now that the phishing email classification application is built, it is time to test its functionality. In this section we will evaluate the application by sending several emails with `cURL` commands and analyzing the responses. We will also inspect the Redis database to confirm that the caching system is working correctly.
Testing the API with cURL Commands
To test the API, we will send five requests to the `/predict` endpoint. Three of them contain unique email texts, and the other two are duplicates of previously sent emails. This lets us verify both the prediction accuracy and the caching mechanism.
echo "\n===== Testing API Endpoint with 5 Requests =====\n"
# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'
# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"
When you run the script above, the API should return a prediction for each email. For the duplicate requests, the responses should be retrieved from the Redis cache, resulting in noticeably faster response times.
sh test.sh
\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries
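To put numbers on the caching benefit rather than just reading the responses, you can time a repeated request. The sketch below uses the requests library, which is not part of the project dependencies, and assumes the FastAPI server is still running on localhost:8000; exact latencies will vary by machine.
import time

import requests

URL = "http://localhost:8000/predict"
payload = {"text": "urgent action required: your account has been compromised, click here to reset your password immediately"}

for attempt in (1, 2):
    start = time.perf_counter()
    response = requests.post(URL, json=payload)
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"Attempt {attempt}: {elapsed_ms:.1f} ms -> {response.json()}")

# The second attempt should be noticeably faster because the prediction
# is served from the Redis cache instead of re-running the model.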
Verifying the Redis Cache
To confirm that the caching system is working correctly, we will inspect the Redis database with the Python script `check_redis.py`. The script retrieves the cached predictions and displays them in a table.
import redis
import json
from tabulate import tabulate
def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

    # Retrieve all keys that start with "prediction:"
    keys = redis_client.keys("prediction:*")
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")

    table_data = []

    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)

        # Retrieve the cached value
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except json.JSONDecodeError:
            data = {}

        prediction = data.get("prediction", "N/A")

        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")

        table_data.append([truncated_text, prediction])

    # Print table using tabulate (only two columns)
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))


if __name__ == "__main__":
    main()
When you run the check_redis.py script, it displays the number of cached entries and the cached predictions in a table. Because two of the five test requests were duplicates, only three unique entries should appear.
python check_redis.py
Total number of cached prediction entries: 3
+--------------------------------------------------+----------------+
| Email Text (First 7 Words) | Prediction |
+--------------------------------------------------+----------------+
| congratulations you have won a free iphone,... | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
| todays floor meeting you may get a... | Safe Email |
+--------------------------------------------------+----------------+
Conclusion
By testing the phishing email classification application with multiple requests, we demonstrated that the API accurately identifies phishing emails while efficiently caching duplicate requests with Redis. This caching mechanism delivers a significant performance boost by eliminating redundant computation for repeated inputs, which is especially valuable in real-world scenarios where the API handles heavy traffic.
Although this is a relatively simple machine learning model, the benefits of caching become even more pronounced with larger, more complex models such as image recognition. For example, if you deploy a large-scale image classification model, caching the predictions for frequently processed inputs can save substantial compute resources and dramatically reduce response times.
Original title: Accelerate Machine Learning Model Serving with FastAPI and Redis Caching. Author: Abid Ali Awan.