
PATTERN

Batched matmul for pairwise similarity

Problem

A service needs to score every candidate against every reference item via a per-pair similarity function — cosine, dot product, Euclidean, etc. The naive implementation is a nested loop over M candidates and N reference items computing M × N independent per-pair kernels. Even when each kernel is cheap, the pattern is expensive at scale:

  • Repeated lookups of the same reference embedding across candidates defeat cache locality.
  • Scattered memory access across non-contiguous embedding storage prevents SIMD vectorisation.
  • Per-pair function-call overhead dominates when D (the embedding dimension) is small-to-moderate.
  • The JIT can't hoist much — each pair is its own tiny computation.
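The naive shape can be sketched as follows (a minimal baseline; the class and method names are illustrative, not from any real service):

```java
// Naive per-pair scoring: M × N calls into a tiny kernel.
// Each call re-reads the same reference row and re-normalises it.
final class NaivePairwise {
    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int d = 0; d < a.length; d++) {
            dot += a[d] * b[d];
            na  += a[d] * a[d];  // re-computed for every pair involving a
            nb  += b[d] * b[d];  // re-computed for every pair involving b
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // M × N nested loop: per-pair call overhead, repeated normalisation,
    // scattered access through candidates[i] / references[j] row pointers.
    static double[][] scoreAll(double[][] candidates, double[][] references) {
        double[][] scores = new double[candidates.length][references.length];
        for (int i = 0; i < candidates.length; i++)
            for (int j = 0; j < references.length; j++)
                scores[i][j] = cosine(candidates[i], references[j]);
        return scores;
    }
}
```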

Solution

Reshape the M × N per-pair computation into a single matrix multiply:

  1. Stack all candidate embeddings into a dense matrix A of shape M × D.
  2. Stack all reference embeddings into a dense matrix B of shape N × D.
  3. Unit-normalise both (‖row‖ = 1) if the kernel is cosine similarity; normalisation is now amortised once per matrix, not once per pair.
  4. Compute C = A × Bᵀ. C[i][j] is the cosine similarity between candidate i and reference j.
  5. Reduce C per row to get the per-candidate aggregate (max, top-k, etc.).
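Steps 1–5 can be sketched in plain Java over flat row-major buffers (an illustrative sketch, not Netflix's implementation; method names are assumptions):

```java
// Batched cosine similarity: normalise once per matrix, then one matmul.
final class BatchedSim {
    // Step 3: unit-normalise each row of a flat row-major rows×d buffer in place.
    static void normaliseRows(double[] m, int rows, int d) {
        for (int i = 0; i < rows; i++) {
            double norm = 0;
            for (int k = 0; k < d; k++) norm += m[i * d + k] * m[i * d + k];
            double inv = 1.0 / Math.sqrt(norm);
            for (int k = 0; k < d; k++) m[i * d + k] *= inv;
        }
    }

    // Step 4: C = A × Bᵀ, i.e. C[i*n + j] = dot(row i of A, row j of B).
    static double[] matmulTransposed(double[] a, int m, double[] b, int n, int d) {
        double[] c = new double[m * n];
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++) {
                double acc = 0;
                for (int k = 0; k < d; k++) acc += a[i * d + k] * b[j * d + k];
                c[i * n + j] = acc;
            }
        return c;
    }

    // Step 5: reduce each row of C to the best per-candidate score.
    static double[] rowMax(double[] c, int m, int n) {
        double[] best = new double[m];
        for (int i = 0; i < m; i++) {
            double mx = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < n; j++) mx = Math.max(mx, c[i * n + j]);
            best[i] = mx;
        }
        return best;
    }
}
```

The inner loop of `matmulTransposed` is the D-length multiply-add chain the rest of the pattern is about: contiguous, branch-free, and a straightforward target for the JIT's auto-vectorisation or an explicit SIMD kernel.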

C = A × Bᵀ is what CPUs and GPUs are built to do fast — every platform has an architecture-optimised path for it via multiply-accumulate hardware primitives (CPU SIMD + FMA, GPU tensor cores; see concepts/matrix-multiplication-accumulate).

Why this helps

  • Cache-friendly. Both matrices are contiguous; rows fit in L1/L2 and stream sequentially through SIMD units.
  • SIMD-native. The inner loop is D parallel multiply-adds with trivial data dependence — exactly the shape vector hardware accelerates. A single fma instruction accumulates 4 (AVX2) or 8 (AVX-512) doubles per step.
  • Amortised overhead. Any per-call setup (thread handoff, JIT warm-up, layout conversion) pays once per batch instead of M × N times.
  • Composable with downstream steps. The output matrix C is itself amenable to further vectorised operations (row-max for top-1, row-softmax for attention, threshold filtering).

Caveats

  • Batching overhead dominates if M or N is small. For a single candidate against a handful of references, the per-pair loop may still win. Netflix's Ranker keeps the single-item implementation for the ~98% of requests that ship one candidate.
  • The reshape alone is not enough. The first implementation must co-design memory layout — flat row-major buffers, not double[][] — and allocation strategy (patterns/flat-buffer-threadlocal-reuse), or the per-batch allocation + GC pressure can erase the kernel win. Netflix measured a ~5% regression on its first cut of batching precisely for this reason.
  • Library choice matters. Netflix tried BLAS via netlib-java and lost to JNI transition overhead plus row-vs-column-major translation. Pure-Java SIMD via the JDK Vector API won instead.
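To make the layout caveat concrete, the difference between `double[][]` and a flat buffer looks like this (illustrative; the reuse pattern referenced above would additionally pool the flat array per thread rather than allocating it per batch):

```java
// A double[][] is an array of row pointers: each row is a separate heap
// object, so a matmul kernel chases a reference per row and cannot stream
// memory sequentially. Flattening into one contiguous row-major buffer
// fixes both: one allocation, and element (i, k) is pure index arithmetic
// at flat[i * d + k].
final class FlatLayout {
    static double[] flatten(double[][] jagged) {
        int rows = jagged.length, d = jagged[0].length;
        double[] flat = new double[rows * d];               // one allocation
        for (int i = 0; i < rows; i++)
            System.arraycopy(jagged[i], 0, flat, i * d, d); // row i at offset i*d
        return flat;
    }
}
```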
