JAX-PCMCI

Causal Discovery in JAX

A bit of personal history first. When I started out in ML, the only resources I had were a mediocre laptop and Google Colab. I shied away from TPUs for a while because they were hard to program for. When I eventually tried them, they were fast: difficult to target, but seemingly well worth the effort. With JAX, I found that if I kept looking, there was always another optimization to be had. Later, when I switched to Kaggle for compute, I discovered that for my workload just one v3 from Kaggle's old TPU v3-8 cluster outperformed the P100 they offered. So I just… learned to program in JAX. Every time I'd dabble in PyTorch or another eager framework, I'd feel the performance gap and go back to JAX.

The Problem

Much later, I ended up working on a study about causal discovery involving an uncommon type of neural network (more details on that later). The gist is that I needed to compare my approach against established causal discovery methods, and one of the most common baselines is PCMCI. If you're unfamiliar: PCMCI is a two-stage algorithm for time-series data. Given N variables measured over T time steps, it estimates which variables drive changes in which others, and at what time delay. Stage one (the PC1 phase) prunes spurious candidate links through iterative conditional independence tests. Stage two (the MCI phase) re-tests each surviving link while conditioning on the parents of both the source and the target, which screens off confounders and cuts the false-positive rate.
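To make the MCI conditioning concrete, here is a minimal sketch of how the condition set for a single link could be assembled. It assumes a hypothetical representation in which `parents[j]` is a set of `(source, lag)` pairs found during PC1, with `lag >= 1` meaning "the source at time t - lag". Neither the function `mci_condition_set` nor this data layout is taken from JAX-PCMCI or Tigramite.

```python
def mci_condition_set(parents, i, tau, j):
    """Condition set for the MCI test of X^i_{t-tau} -> X^j_t.

    parents: hypothetical dict mapping each variable index to a set of
    (source_index, lag) pairs from the PC1 phase, where lag >= 1 means
    "source at time t - lag".
    """
    # Parents of the target X^j_t, minus the link currently being tested.
    target_parents = set(parents.get(j, set())) - {(i, tau)}
    # Parents of the source X^i_{t-tau}: a parent (k, lag) of X^i_t sits
    # at t - tau - lag relative to the target's time t, so shift its lag.
    source_parents = {(k, lag + tau) for (k, lag) in parents.get(i, set())}
    return target_parents | source_parents


# Toy parent sets: X0 is autocorrelated; X1 depends on X0(t-1) and X1(t-1).
parents = {0: {(0, 1)}, 1: {(0, 1), (1, 1)}}
print(sorted(mci_condition_set(parents, i=0, tau=1, j=1)))  # [(0, 2), (1, 1)]
```

Implementations may also cap how many source parents enter the condition set to keep tests well-powered; that refinement is left out of this sketch.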

The standard implementation lives in a library called Tigramite. It was horrifically slow for my use case. It's CPU-only, and while it technically supports a GPU-accelerated independence test, that option just doesn't work. I fought with it for a while before deciding: I'll make my own library.

Building It

The thing that made the rewrite worthwhile is that the core of PCMCI, thousands of conditional independence tests, is "embarrassingly parallel": each test operates on an independent slice of the data and shares no mutable state with the others. In Tigramite they run sequentially; in JAX, you can vmap a single-test function across the entire batch and run hundreds simultaneously on a GPU. Then jit compiles the pipeline through XLA, fusing operations and eliminating per-test Python overhead. Finally, I used pmap to distribute the work across both GPUs, because Kaggle offers a T4×2 instance.
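A minimal sketch of the vmap/jit/pmap pattern described above. A plain (unconditional) Pearson correlation stands in for a real conditional independence test, and `corr_stat`, `batched_corr`, and `sharded_corr` are illustrative names, not JAX-PCMCI's API.

```python
import jax
import jax.numpy as jnp


def corr_stat(x, y):
    """Pearson correlation for ONE pair of length-T series (no batch axis)."""
    xc = x - jnp.mean(x)
    yc = y - jnp.mean(y)
    return jnp.sum(xc * yc) / jnp.sqrt(jnp.sum(xc**2) * jnp.sum(yc**2))


# vmap lifts the single-test function over a leading batch axis;
# jit compiles the batched version into one fused XLA program.
batched_corr = jax.jit(jax.vmap(corr_stat))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (256, 250))  # 256 tests, 250 time steps each
ys = jax.random.normal(key_y, (256, 250))
r = batched_corr(xs, ys)  # shape (256,): one statistic per test

# Multi-device variant: give each local device its own slice of the batch.
# The batch size must divide evenly by the device count.
n_dev = jax.local_device_count()
sharded_corr = jax.pmap(jax.vmap(corr_stat))
r_sharded = sharded_corr(
    xs.reshape(n_dev, -1, 250), ys.reshape(n_dev, -1, 250)
).reshape(-1)
```

Newer JAX documentation tends to steer multi-device work toward `jit` with sharding annotations or `shard_map`, but `pmap` expresses the "one slice per GPU" idea most directly, which matches the T4×2 setup.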

I would like to say there was just one hard part, but there were many, many, many hard parts. If you're curious, the git commits for every version tell the story. There were difficulties figuring out when padding is worthwhile, making the library work on all devices (not just the ones I was using), and a couple of other things. Nothing impossible to fix, just a lot of trial and error. The biggest challenge was probably the padding and masking strategy. The number of tests in PCMCI varies wildly depending on the data and the stage of the algorithm. In the early PC1 phase, there are often hundreds of thousands of tests to run, which is perfect for GPU parallelism. But as the algorithm prunes links, the test count can drop to just a few dozen, which leaves the GPU underutilized and slows things down. On top of that, jit compiles a separate program for every distinct input shape, so a batch size that keeps changing also means repeated recompilation. My solution was to pad the test batch up to a fixed size (I settled on 512) and mask out the invalid entries.
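One way to implement the pad-and-mask idea, sketched under assumptions: batches are rounded up to a multiple of 512, padded rows are zeros, and a boolean mask marks the real tests. Because `jax.jit` caches one compiled program per input shape, every batch that lands in the same bucket reuses it; the `TRACE_COUNT` counter makes that visible. `pad_tests`, `masked_corr`, and `run_tests` are hypothetical names, and the correlation is again a stand-in for a real test.

```python
import jax
import jax.numpy as jnp

PAD_TO = 512      # fixed bucket size from the post
TRACE_COUNT = 0   # counts how often JAX traces (and therefore compiles) the kernel


def pad_tests(xs, ys, pad_to=PAD_TO):
    """Zero-pad a (n_tests, T) batch up to the next multiple of pad_to."""
    n = xs.shape[0]
    padded = -(-n // pad_to) * pad_to  # ceiling to a multiple of pad_to
    widths = ((0, padded - n), (0, 0))
    mask = jnp.arange(padded) < n      # True for real tests, False for padding
    return jnp.pad(xs, widths), jnp.pad(ys, widths), mask


@jax.jit
def masked_corr(xs, ys, mask):
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs only while tracing a new shape
    xc = xs - xs.mean(axis=1, keepdims=True)
    yc = ys - ys.mean(axis=1, keepdims=True)
    num = jnp.sum(xc * yc, axis=1)
    den = jnp.sqrt(jnp.sum(xc**2, axis=1) * jnp.sum(yc**2, axis=1))
    ok = mask & (den > 0)
    # Double-where keeps all-zero padded rows from producing 0/0 = NaN.
    return jnp.where(ok, num / jnp.where(ok, den, 1.0), 0.0)


def run_tests(xs, ys):
    n = xs.shape[0]
    return masked_corr(*pad_tests(xs, ys))[:n]  # drop the padded entries


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
small = run_tests(jax.random.normal(k1, (37, 250)), jax.random.normal(k2, (37, 250)))
large = run_tests(jax.random.normal(k3, (300, 250)), jax.random.normal(k4, (300, 250)))
# Both batches pad to 512 rows, so masked_corr is compiled only once.
```

A batch of, say, 600 tests would pad to 1024 and trigger one more compilation; the trade-off is some wasted work on padded rows in exchange for a small, fixed set of compiled shapes.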

Memory management was another headache. A naive approach would materialize the full N×τ×N×τ test matrix at once, blowing past VRAM limits for larger systems. Instead, the batching module processes tests in chunks, sizing batches automatically based on available device memory. Numerical stability was a recurring theme too: partial correlations near ±1 produce infinite Fisher z-values, and the k-NN mutual information estimator can return negative values (impossible for the true quantity, which is never negative) when samples are scarce relative to dimensionality. Every test includes clipping, floor values, and degeneracy guards, all of course implemented as branchless JAX operations to stay JIT-friendly.
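A minimal sketch of the kinds of branchless guards and chunking described above. The helper names, the `eps` and `tiny` thresholds, and the sizing rule in `chunk_size_for` are assumptions for illustration; the post does not state JAX-PCMCI's actual constants or memory heuristic. Using `jnp.clip`, `jnp.maximum`, and `jnp.where` instead of Python `if` statements keeps every function traceable under `jit`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def safe_corr(rx, ry, tiny=1e-12):
    """Correlation of two residual vectors; returns 0.0 if either is constant."""
    num = jnp.sum(rx * ry)
    den = jnp.sqrt(jnp.sum(rx**2) * jnp.sum(ry**2))
    ok = den > tiny
    # Double-where: the division never sees a zero denominator.
    return jnp.where(ok, num / jnp.where(ok, den, 1.0), 0.0)


def fisher_z_pvalue(r, n_samples, n_conds, eps=1e-6):
    """Two-sided p-value for a (partial) correlation via Fisher's z."""
    r = jnp.clip(r, -1.0 + eps, 1.0 - eps)        # keeps arctanh finite near +/-1
    dof = jnp.maximum(n_samples - n_conds - 3, 1)  # guards tiny effective samples
    z = jnp.arctanh(r) * jnp.sqrt(dof)
    return 2.0 * norm.cdf(-jnp.abs(z))


def floor_cmi(cmi_estimate):
    """True CMI is non-negative, so clamp small-sample negative estimates to 0."""
    return jnp.maximum(cmi_estimate, 0.0)


def chunk_size_for(budget_bytes, n_samples, arrays_per_test=2, itemsize=4):
    """Hypothetical sizing rule: how many tests' float32 inputs fit in a budget."""
    per_test = n_samples * arrays_per_test * itemsize
    return max(1, budget_bytes // per_test)


def run_in_chunks(fn, xs, ys, chunk):
    """Apply a batched test function chunk by chunk instead of all at once."""
    outs = [fn(xs[s:s + chunk], ys[s:s + chunk])
            for s in range(0, xs.shape[0], chunk)]
    return jnp.concatenate(outs)


p = fisher_z_pvalue(jnp.array([0.0, 0.3, 1.0]), n_samples=250, n_conds=2)
```

Even the perfect correlation in the last entry yields a finite p-value instead of an `inf` z-score.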

The Independence Tests

The library comes with four conditional independence tests. ParCorr (partial correlation) is the fastest: regress X and Y each on the conditioning set Z, correlate the residuals, and compute a p-value via Fisher's z-transformation. It's a single JIT-compiled function and handles linear dependencies well. CMI-kNN captures nonlinear dependencies using the Kraskov-Stögbauer-Grassberger (KSG) estimator: find each point's k-th nearest neighbor in the joint (X, Y, Z) space, count the neighbors within that radius in the (X, Z), (Y, Z), and Z subspaces, and take the digamma difference. CMISymbolic is a fast discretized alternative for when you need nonlinear detection but can't afford permutation testing. Finally, GPDCond fits Gaussian process regressions and runs a distance correlation test on the residuals, which can pick up complex nonlinear structures that the others miss.
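Minimal sketches of the two tests whose mechanics are spelled out above. They are illustrations under stated assumptions, not JAX-PCMCI's code: `parcorr` uses an intercept-plus-Z least-squares fit and a Fisher z p-value with T - d - 3 degrees of freedom, and `cmi_knn` uses the conditional (Frenzel-Pompe) form of the KSG estimator with max-norm distances. `cmi_knn` returns only the statistic; a p-value would need a separate scheme such as permutation testing, and its all-pairs distance matrices make memory grow as O(T²).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma
from jax.scipy.stats import norm


def parcorr(x, y, z):
    """ParCorr-style test. x, y: (T,); z: (T, d). Returns (r, p_value)."""
    T, d = z.shape
    design = jnp.column_stack([jnp.ones(T), z])  # intercept + conditioning set
    # Regress X and Y on Z and keep the residuals.
    rx = x - design @ jnp.linalg.lstsq(design, x)[0]
    ry = y - design @ jnp.linalg.lstsq(design, y)[0]
    r = jnp.sum(rx * ry) / jnp.sqrt(jnp.sum(rx**2) * jnp.sum(ry**2))
    r = jnp.clip(r, -1.0 + 1e-6, 1.0 - 1e-6)
    # Fisher z-transform: approximately N(0, 1) under independence.
    zstat = jnp.arctanh(r) * jnp.sqrt(T - d - 3.0)
    return r, 2.0 * norm.cdf(-jnp.abs(zstat))


def _maxnorm_dists(a):
    """Pairwise max-norm (Chebyshev) distances for a (T, dim) array."""
    return jnp.max(jnp.abs(a[:, None, :] - a[None, :, :]), axis=-1)


def cmi_knn(x, y, z, k=5):
    """KSG-style estimate of I(X; Y | Z) in nats, floored at zero."""
    T = x.shape[0]
    dx = _maxnorm_dists(x[:, None])
    dy = _maxnorm_dists(y[:, None])
    dz = _maxnorm_dists(z)
    not_self = ~jnp.eye(T, dtype=bool)
    joint = jnp.where(not_self, jnp.maximum(jnp.maximum(dx, dy), dz), jnp.inf)
    eps = jnp.sort(joint, axis=1)[:, k - 1]  # distance to k-th joint neighbour

    def count(d):  # neighbours strictly within eps in a subspace, excluding self
        return jnp.sum((d < eps[:, None]) & not_self, axis=1).astype(jnp.float32)

    n_xz = count(jnp.maximum(dx, dz))
    n_yz = count(jnp.maximum(dy, dz))
    n_z = count(dz)
    cmi = digamma(float(k)) - jnp.mean(
        digamma(n_xz + 1.0) + digamma(n_yz + 1.0) - digamma(n_z + 1.0)
    )
    return jnp.maximum(cmi, 0.0)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(1), 4)
T = 300
z = jax.random.normal(k1, (T, 1))
x = jax.random.normal(k2, (T,))
y_dep = x + 0.1 * jax.random.normal(k3, (T,))  # strongly dependent on x
y_ind = jax.random.normal(k4, (T,))            # independent of x
r_dep, p_dep = parcorr(x, y_dep, z)
c_dep, c_ind = cmi_knn(x, y_dep, z), cmi_knn(x, y_ind, z)
```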

How It Turned Out

The result was shockingly good. Even with all its inefficiencies, and before the later optimizations, it ran many times faster than Tigramite, although it still took hours on my dataset, so I kept improving it. My first couple of implementations (everything before version 1.3.0) admittedly had flaws that made the whole thing many times slower than it should have been. After many more tricks and optimizations, it reached a point where I struggled to find anything left to improve. On the integrated GPU of a single Ryzen 9 4900HS laptop, it runs PCMCI on a 20-variable, 250-timestep system in about 0.12 seconds. PCMCI+ (which adds contemporaneous link discovery) comes in around 0.23 seconds. As an added bonus, the scaling advantage grows with problem size: more variables and deeper lags fill GPU cores that would otherwise sit idle in a sequential implementation.

I figured there might be somebody out there in a similar situation. My dataset wasn't that large, and there are good uses for running causal discovery on larger datasets. So I put the library on PyPI, set everything up, and got it working. Ironically, for its original purpose it worked a bit too well. It was quite accurate, and its speed meant I could no longer claim my method was 47x faster. That said, the comparison is now more truthful, so I'm not particularly mad about it (and my method still ended up being quicker).

Later, out of curiosity, I checked pypistats and was horrified to discover that a few people were actually using it, which I was not expecting. That realization was part of what prompted me to start this blog. I'd noticed the JAX-PCMCI GitHub repo hadn't even been indexed by Google yet, so, alas, the burden of making it somewhat discoverable falls on me. Hence this post: an attempt to bring a small amount of attention to the project.

As for the visual effect in the background: in keeping with the theme of interesting minimalist aesthetics, it is meant to represent the actual PCMCI process. Each column of nodes is a time step, and the charcoal edges between them are the candidate associations that PCMCI considers during the PC1 phase. The red edges are the ones that survive: validated causal links that persisted through the conditioning gauntlet. The whole thing scrolls left, like time-series data flowing through the algorithm. Hover over a node to see the other nodes it has a relationship with.

Structural Causal Vocabulary

Time-lagged links. A causal link $X_{t-\tau} \to Y_t$ means the state of X at τ steps in the past influences the current state of Y. PCMCI systematically tests all such links up to a maximum lag $\tau_{\max}$.
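A small sketch of how lagged links can be laid out as data, assuming a (T, N) array layout and positive lags; `lagged_pair` and `candidate_links` are hypothetical helpers. With lags 1 through tau_max there are N * N * tau_max lagged candidates (self-links included, which capture autocorrelation), and each becomes a pair of aligned vectors.

```python
import jax
import jax.numpy as jnp


def lagged_pair(data, i, j, tau, tau_max):
    """Align X^i_{t-tau} with X^j_t for t = tau_max, ..., T-1.

    data: (T, N) array. Using the same time window for every lag keeps
    every test on an equal sample count of T - tau_max.
    """
    T = data.shape[0]
    x_lagged = data[tau_max - tau : T - tau, i]
    y_now = data[tau_max:, j]
    return x_lagged, y_now


def candidate_links(n_vars, tau_max):
    """All lagged candidates (i, tau, j), meaning X^i_{t-tau} -> X^j_t."""
    return [(i, tau, j) for j in range(n_vars) for i in range(n_vars)
            for tau in range(1, tau_max + 1)]


# Toy system: X drives Y with a 2-step delay.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
T, tau_max = 500, 5
x = jax.random.normal(kx, (T,))
y = 0.8 * jnp.roll(x, 2) + 0.5 * jax.random.normal(ky, (T,))  # y_t = 0.8 x_{t-2} + noise
data = jnp.column_stack([x, y])

lag_corrs = jnp.array([
    jnp.corrcoef(*lagged_pair(data, 0, 1, tau, tau_max))[0, 1]
    for tau in range(1, tau_max + 1)
])
best_lag = int(jnp.argmax(jnp.abs(lag_corrs))) + 1  # should be 2
```

`jnp.roll` wraps the first two samples around, but those fall before `tau_max` and are never used.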

Graph mutilation. Pearl's do-calculus formalizes intervention. To compute the effect of forcing X to a value, you "mutilate" the graph — delete all incoming edges to X. This severs X from its natural causes, isolating only its downstream effects.
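A minimal sketch of mutilation on a linear structural causal model, assuming hypothetical helpers `mutilate` and `sample` and a weighted adjacency matrix `B` with nodes indexed in topological order. The observational slope of Y on X mixes the direct effect with the path through the confounder Z, while the mutilated model isolates the direct effect.

```python
import jax
import jax.numpy as jnp


def mutilate(B, noise, do):
    """Apply hard interventions do = {node: value} to a linear SCM.

    B[i, j] is the weight of edge i -> j. Zeroing column k deletes every
    incoming edge to node k; overwriting its noise column pins it to value.
    """
    for k, value in do.items():
        B = B.at[:, k].set(0.0)
        noise = noise.at[:, k].set(value)
    return B, noise


def sample(B, noise):
    """Ancestral sampling; assumes nodes are indexed in topological order."""
    v = jnp.zeros_like(noise)
    for j in range(B.shape[0]):
        v = v.at[:, j].set(v @ B[:, j] + noise[:, j])
    return v


# Nodes: 0 = Z (confounder), 1 = X, 2 = Y.  Z -> X, Z -> Y, X -> Y (weight 0.5).
B = jnp.array([[0.0, 1.0, 1.0],
               [0.0, 0.0, 0.5],
               [0.0, 0.0, 0.0]])
noise = jax.random.normal(jax.random.PRNGKey(0), (20_000, 3))

obs = sample(B, noise)
slope_obs = jnp.cov(obs[:, 1], obs[:, 2])[0, 1] / jnp.var(obs[:, 1], ddof=1)

y_do1 = sample(*mutilate(B, noise, {1: 1.0}))[:, 2]
y_do0 = sample(*mutilate(B, noise, {1: 0.0}))[:, 2]
effect = jnp.mean(y_do1 - y_do0)  # direct causal effect of X on Y
```

With these weights the observational slope comes out near 1.0, double the true effect of 0.5. Reusing the same noise draws for both interventions (common random numbers) makes the interventional difference equal the structural coefficient sample by sample.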

Colliders. In the structure X → Z ← Y, the node Z is a collider. X and Y are marginally independent — until you condition on Z, which paradoxically opens a spurious path between them. PCMCI must carefully navigate these structures to avoid both false positives and false negatives.
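A short simulation of the collider effect, assuming a linear-Gaussian toy model and a hypothetical `partial_corr` helper that correlates regression residuals. X and Y are generated independently, yet conditioning on their common effect Z induces a strong negative partial correlation (analytically -0.8 for these coefficients).

```python
import jax
import jax.numpy as jnp


def partial_corr(x, y, z):
    """Correlation of the residuals of x and y after regressing each on [1, z]."""
    design = jnp.column_stack([jnp.ones(x.shape[0]), z])
    rx = x - design @ jnp.linalg.lstsq(design, x)[0]
    ry = y - design @ jnp.linalg.lstsq(design, y)[0]
    return jnp.corrcoef(rx, ry)[0, 1]


kx, ky, kz = jax.random.split(jax.random.PRNGKey(0), 3)
n = 20_000
x = jax.random.normal(kx, (n,))
y = jax.random.normal(ky, (n,))
z = x + y + 0.5 * jax.random.normal(kz, (n,))  # collider: X -> Z <- Y

marginal = jnp.corrcoef(x, y)[0, 1]  # near 0: X and Y are independent
conditional = partial_corr(x, y, z)  # near -0.8: conditioning on Z opens a path
```

This is why a conditioning set that wrongly includes a collider can manufacture a link that does not exist.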

Hidden confounders. A latent common cause Z driving both X and Y creates the illusion of a direct X → Y link. The conditioning strategy in PCMCI's MCI step is specifically designed to screen off such confounding, provided the confounder's effects are captured by the included variables and time lags.
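A toy, time-lagged version of confounding, assuming a linear-Gaussian system and the same kind of hypothetical residual-based `partial_corr` helper. Z drives X after one step and Y after two, so the candidate link $X_{t-1} \to Y_t$ looks real until $Z_{t-2}$ enters the conditioning set. If Z were not among the measured variables, nothing could remove the spurious correlation, which is the caveat in the paragraph above.

```python
import jax
import jax.numpy as jnp


def partial_corr(x, y, z):
    """Correlation of the residuals of x and y after regressing each on [1, z]."""
    design = jnp.column_stack([jnp.ones(x.shape[0]), z])
    rx = x - design @ jnp.linalg.lstsq(design, x)[0]
    ry = y - design @ jnp.linalg.lstsq(design, y)[0]
    return jnp.corrcoef(rx, ry)[0, 1]


kz, kx, ky = jax.random.split(jax.random.PRNGKey(0), 3)
T = 20_000
z = jax.random.normal(kz, (T,))
x = jnp.roll(z, 1) + jax.random.normal(kx, (T,))  # X_t = Z_{t-1} + noise
y = jnp.roll(z, 2) + jax.random.normal(ky, (T,))  # Y_t = Z_{t-2} + noise

# Candidate link X_{t-1} -> Y_t, aligned for t = 2, ..., T-1
# (this also drops the samples that jnp.roll wrapped around).
x_lag1, y_now, z_lag2 = x[1:-1], y[2:], z[:-2]

spurious = jnp.corrcoef(x_lag1, y_now)[0, 1]   # about 0.5, though X never drives Y
screened = partial_corr(x_lag1, y_now, z_lag2)  # about 0 once Z_{t-2} is conditioned on
```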