Trending
5 phút đọc2 tháng 6, 20261

FlashLib: Khi scikit-learn chạy trên GPU nhanh như Flash

FlashLib mang 16 thuật toán ML cổ điển lên GPU với Triton kernels. KMeans, PCA, DBSCAN, UMAP... tất cả chạy trên CUDA, API quen thuộc như sklearn.

N

Nguyễn Nhật Long

@nguyennhatlong1303

FlashLib: Khi scikit-learn chạy trên GPU nhanh như Flash

Bạn vẫn đang chạy KMeans trên CPU à?

Mình nhớ lần đầu phải cluster 5 triệu vectors 128 chiều cho một bài toán recommendation. Chạy scikit-learn trên CPU, pha cà phê, đi ăn trưa về vẫn chưa xong. Chuyển sang FAISS thì nhanh hơn nhưng API khác hoàn toàn, debug cũng mệt. Đó là câu chuyện của mấy năm trước.

Tuần rồi mình thấy một repo mới trên GitHub đang lên nhanh FlashLib của FlashML-org và nói thật, đây có thể là thứ mà nhiều anh em ML engineer Việt Nam đang cần mà chưa biết.

FlashLib là gì và tại sao nó đáng chú ý?

FlashLib là một GPU library cho các classical machine learning operators không phải deep learning, mà là những thuật toán "cơ bản" như KMeans, KNN, PCA, DBSCAN, UMAP, t-SNE, regression... Tất cả được viết lại bằng TritonCuteDSL để chạy native trên GPU.

Điểm khác biệt lớn nhất: API của nó gần như drop-in replacement cho scikit-learn. Bạn không cần học framework mới, không cần viết CUDA kernel, chỉ cần thay from sklearn.cluster import KMeans thành from flashlib import flash_kmeans là xong.

Python
1import torch
2from flashlib import flash_kmeans
3
4x = torch.randn(1_000_000, 128, device="cuda", dtype=torch.float32)
5labels, centroids, n_iter = flash_kmeans(x, n_clusters=1024, max_iters=20)

6 dòng code, 1 triệu points, chạy trên GPU. Đơn giản vậy thôi.

16 primitives, phủ gần hết workflow ML truyền thống

Điều mình thấy hay là FlashLib không chỉ làm một vài thuật toán rồi bỏ đó. Họ ship 16 high-level primitives chia theo nhóm rõ ràng:

Ngoài ra còn có một bộ low-level linear algebra primitives (cov_gemm, gram_gemm, eigh, cholqr2...) và cả một bộ multi-precision GEMM variants từ TF32, BF16, FP16 cho đến Ozaki2 INT8. Nếu bạn đang cần tối ưu matrix multiplication ở nhiều precision levels, đây là goldmine.

FamilyPrimitives
**Clustering**`flash_kmeans`, `flash_dbscan`, `flash_hdbscan`, `flash_spectral_clustering`
**Nearest Neighbors**`flash_knn`, `flash_ivf_flat` (IVF-Flat ANN)
**Decomposition**`flash_pca`, `flash_truncated_svd`
**Manifold**`flash_umap`, `flash_tsne`
**Regression**`flash_linear_regression`, `flash_ridge`, `flash_logistic_regression`
**Classification**`flash_multinomial_nb`, `flash_random_forest`
**Preprocessing**`flash_standard_scaler`

IVF-Flat: Approximate Nearest Neighbors ngay trên GPU

Một feature mình đặc biệt quan tâm là IVF-Flat approximate nearest neighbors search. Đây là bài toán cực kỳ phổ biến trong recommendation, RAG pipeline, image retrieval...

Python
1import torch
2from flashlib import IVFFlat
3
4db = torch.randn(1_000_000, 128, device="cuda")
5queries = torch.randn(10_000, 128, device="cuda")
6
7index = IVFFlat(nlist=1024, nprobe=16).fit(db)
8distances, indices = index.kneighbors(queries, n_neighbors=10)

API giống hệt sklearn pattern: fit() rồi kneighbors(). Tham số nprobe là recall knob tăng lên thì recall cao hơn nhưng chậm hơn, giảm xuống thì nhanh hơn nhưng recall thấp hơn. Theo docs, ở cùng (nlist, nprobe), recall match với FAISS / cuVS, nên bạn có thể so sánh trực tiếp.

Theo kinh nghiệm của mình, đây là điểm mà nhiều team Việt Nam hay gặp pain: muốn dùng FAISS GPU thì phải cài CUDA toolkit đúng version, compile lại, debug segfault... FlashLib dùng Triton nên installation sạch sẽ hơn nhiều chỉ cần pip install flashlib.

Informative API: Ước lượng cost trước khi chạy

Đây là feature mà mình chưa thấy library nào khác làm. Module flashlib.info cho phép bạn predict runtime, FLOPs, và HBM bytes cho bất kỳ primitive nào chỉ trong khoảng 5 microseconds trên CPU, không cần GPU.

Python
1import flashlib.info as info
2
3est = info.estimate(
4 "kmeans",
5 shape=(100_000, 64),
6 params={"K": 256, "max_iters": 20},
7 device="H200"
8)
9print(est.summary_line())

Module này không import torch, triton, hay cutlass nghĩa là bạn có thể chạy nó trong một environment hoàn toàn CPU-only. Use case rõ ràng nhất:

  • Budget pipeline trước khi chạy: Biết trước operation nào tốn bao nhiêu memory, mất bao lâu, để plan resource allocation.
  • LLM agent integration: Một AI agent có thể gọi API này để quyết định nên dùng thuật toán nào, config ra sao, mà không cần access GPU.

Điều mình thấy thú vị là mindset đằng sau: họ không chỉ muốn làm library nhanh, mà còn muốn nó predictable. Trong production, predictable quan trọng không kém gì fast.

So sánh nhanh với các lựa chọn hiện tại

Để bạn có cái nhìn tổng quan, đây là so sánh FlashLib với các alternatives phổ biến:

Điểm yếu hiện tại: FlashLib mới ở giai đoạn đầu (6 commits trên GitHub), coverage chưa bằng scikit-learn hay cuML. Nhưng 426 stars cho một repo mới cho thấy cộng đồng đang rất quan tâm.

Tiêu chíscikit-learnFAISS (GPU)cuML (RAPIDS)FlashLib
**Chạy trên**CPUCPU + GPUGPUGPU
**Cài đặt**`pip install`Build từ source hoặc condaconda (RAPIDS ecosystem)`pip install`
**API style**sklearnCustom APIsklearn-likesklearn-like + functional
**Backend**NumPy/CythonCUDA C++CUDA C++Triton + CuteDSL
**Coverage**Rất rộngANN + clusteringML algorithms16 primitives + LA ops
**Cost estimation**KhôngKhôngKhôngCó (`flashlib.info`)
**Multi-precision GEMM**KhôngKhôngHạn chếNhiều variants

Ai sẽ được lợi nhất?

Nếu bạn thuộc một trong những nhóm sau, FlashLib đáng để thử:

  • ML Engineers đang chạy preprocessing/clustering pipeline trên dataset lớn và muốn tận dụng GPU mà không đổi sang ecosystem khác.
  • MLOps/Infra teams cần estimate resource trước khi schedule jobs flashlib.info giải quyết đúng pain point này.
  • Các team đang dùng FAISS cho ANN search nhưng mệt mỏi với installation và muốn một alternative dễ cài hơn.
  • Researchers cần multi-precision GEMM để thí nghiệm numerical stability ở các precision levels khác nhau.

Những điều cần lưu ý

Mình không muốn chỉ nói toàn điều tốt. Vài điểm bạn nên cân nhắc:

  • Repo còn rất mới: 6 commits, chưa có issue nào. Điều này có thể nghĩa là chưa nhiều người dùng production, hoặc code quá tốt mình nghiêng về khả năng đầu.
  • Phụ thuộc vào Triton: Triton hiện tại chủ yếu support NVIDIA GPUs. Nếu bạn chạy AMD hay Intel GPU, chưa chắc đã dùng được.
  • Chưa rõ production stability: Với một library mới, bạn nên benchmark kỹ trên data thật trước khi đưa vào production pipeline.

Hướng đi tiếp theo

FlashLib đang ở giai đoạn mà mọi contribution đều có impact lớn. Nếu bạn quan tâm đến GPU programming với Triton, đây là một codebase rất tốt để học viết kernel cho các thuật toán ML quen thuộc, có benchmark sẵn để verify.

Theo mình, trend này sẽ tiếp tục: classical ML trên GPU không phải điều mới (cuML đã làm từ lâu), nhưng việc dùng Triton thay vì CUDA C++ để viết kernel là một bước tiến lớn về accessibility. Nhiều Python developer hơn có thể đọc, hiểu, và contribute vào kernel code.

Repo: github.com/FlashML-org/flashlib

Cài thử đi, chạy flash_kmeans trên dataset của bạn, rồi so với scikit-learn. Mình nghĩ kết quả sẽ khiến bạn ngạc nhiên.

NN

Nguyễn Nhật Long

@nguyennhatlong1303

Nguyễn Nhật Long is a Senior Frontend Engineer and Frontend Team Leader with 7 years of experience building real-time fintech platforms. Specializing in React, Next.js, TypeScript, and React Native, shipping 10+ products across Web, Mobile, Telegram Mini-Apps, and Web3.

Thấy hay? Chia sẻ cho bạn bè!