
https://rachitsingh.com/rss.xml


Posts

Training MoEs - efficiency
$$ \newcommand{\on}[1]{\operatorname{#1}} \newcommand{\dmodel}{d_{\on{model}}} \newcommand{\dexpert}{d_{\on{expert}}} $$

MoEs are a new major paradigm for scaling models upwards without necessarily increasing compute costs, and they seem to work quite well. There's a lot of interesting work here, but this blog post focuses on recent work on increasing the efficiency of MoE kernels. I'm mostly writing about my understanding as I read it.

I'm reading two papers about scaling MoEs:

  • SonicMoE (Guo et al), which is from the Tri Dao and Stoica groups about new MoE training kernel optimizations, along with the associated blog post
  • Scaling Laws for Fine-grained Mixture of Experts (Krajewski et al) about how to change your MoE granularity as the flop and token budget increases (along with a new scaling law)

They actually use different definitions of "granularity"! They're very related:

  • $G_{\on{Sonic}} = d_{\on{model}} / d_{\on{expert}}$ is the ratio of the embedding or residual size and the intermediate size of a single expert
  • $G_{\on{Krajewski}} = d_{\on{ff}} / d_{\on{expert}}$ is the ratio of the "original" dense feed-forward intermediate size and the size of a single expert
  • in practice, usually $d_{\on{ff}} = 4\cdot d_{\on{model}}$ , so they're related, but the former is internally defined (i.e. a function of the architecture), and the latter is kind of dependent on a reference dense transformer model

Let's define granularity to be $G_{\on{Sonic}}$, i.e. the ratio between the embedding size and the expert size, since it's calculable without a comparison to a reference dense model. $K$ is the number of activated experts per token, and $E$ is the total number of experts. SonicMoE defines sparsity as $\rho = K/E$.

The SonicMoE blog post has a nice description of some recent models, which I've fleshed out a little here:

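| model | year | $d_{\on{model}}$ | $d_{\on{expert}}$ | $K$ | $E$ | $G$ | $\rho$ |
|---|---|---|---|---|---|---|---|
| Mixtral 8x22B | 2024 | 6144 | 16384 | 2 | 8 | 0.375 | 0.250 |
| DeepSeek V2 | 2024 | 5120 | 1536 | 6 | 160 | 3.33 | 0.0375 |
| DeepSeek V3.2 | 2025 | 7168 | 2048 | 8 | 256 | 3.5 | 0.031 |
| Kimi K2.5 | 2025 | 7168 | 2048 | 8 | 384 | 3.5 | 0.021 |
| Qwen3-Next-80B-A3B-Instruct | 2025 | 2048 | 512 | 10 | 512 | 4.0 | 0.020 |
| Qwen3.5-397B-A17B | 2026 | 4096 | 1024 | 10 | 512 | 4.0 | 0.020 |
| Qwen3.6-35B-A3B | 2026 | 2048 | 512 | 8 | 256 | 4.0 | 0.031 |
| Arcee Trinity Large | 2026 | 3072 | 3072 | 4 | 256 | 1.0 | 0.016 |
| z.AI GLM-5.1 | 2026 | 6144 | 2048 | 8 | 256 | 3.0 | 0.031 |
| MiniMax M2.5 | 2026 | 3072 | 1536 | 8 | 256 | 2.0 | 0.031 |
| Ant Ling 2.5-1T | 2026 | 8192 | 2048 | 8 | 256 | 4.0 | 0.031 |
| DeepSeek V4 Flash | 2026 | 4096 | 2048 | 6 | 256 | 2.0 | 0.023 |
| DeepSeek V4 Pro | 2026 | 7168 | 3072 | 6 | 384 | 2.33 | 0.016 |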

Note that many of the above models have a "shared expert" (which is always activated), but that's not included in the sparsity calculation above. There's a lot of variability in $G$ and $\rho$ even in models from 2026!
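As a quick sanity check on the last two columns, here is a small sketch that recomputes $G$ and $\rho$ from the other columns for a few rows of the table. The helper names (`granularity`, `sparsity`) are mine, and the shared expert is ignored, as in the table.

```python
# (d_model, d_expert, K, E) copied from a few rows of the table above.
models = {
    "Mixtral 8x22B": (6144, 16384, 2, 8),
    "DeepSeek V2": (5120, 1536, 6, 160),
    "Kimi K2.5": (7168, 2048, 8, 384),
    "DeepSeek V4 Pro": (7168, 3072, 6, 384),
}


def granularity(d_model, d_expert):
    """G = d_model / d_expert (the SonicMoE definition)."""
    return d_model / d_expert


def sparsity(k, e):
    """rho = K / E, ignoring any always-on shared expert."""
    return k / e


summary = {
    name: (granularity(dm, de), sparsity(k, e))
    for name, (dm, de, k, e) in models.items()
}
for name, (g, rho) in summary.items():
    print(f"{name:16s} G={g:.3f} rho={rho:.4f}")
```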

Krajewski et al. showed that the compute optimal hyperparameters for MoE models are increasingly granular (see Table 2). For example, for a pretraining run like DeepSeek V4 Pro's, which was apparently approximately 1e25 FLOPs, we would want $G=64$ (and only $8T$ tokens), which is way bigger than even the most granular model above. So how do we get there?

There are two big issues1:

  1. As we increase the granularity (i.e. make the expert size $d_{\on{expert}}$ smaller relative to the embedding size $d_{\on{model}}$), we reduce the number of flops used per activated expert. In order to keep the flops constant (i.e. use our whole budget), we'd need to activate more experts. However, in the forward and backward pass we have activations whose size depends not on $d_{\on{expert}}$ but on $d_{\on{model}}$, which means that they just get bigger as we increase $K$.
  2. Increasing the granularity decreases the arithmetic intensity of the kernels, which means we quickly become memory bound.

The latter is spelled out in the paper. Let's take a look at the forward pass and assume that when we concatenate the $M$ rows from the tokens routed to the expert we're looking at, it forms a matrix $X_e \in \mathbb{R}^{M \times d_{\on{model}}}$. We know that $M$ is on average $TK/E = T\rho$, so let's consider that perfectly balanced best case. Then for SwiGLU, we have to up-project, gate, and down-project: the first two are multiplications by matrices of size $(d_{\on{model}}, d_{\on{expert}})$, and the last by a matrix of size $(d_{\on{expert}}, d_{\on{model}})$. Each matmul is $2 d_{\on{model}}d_{\on{expert}}M$ FLOPs, for a total of $6 d_{\on{model}}d_{\on{expert}}T\rho$ FLOPs, which is the numerator of the arithmetic intensity calculation.

For the denominator, assuming everything is bf16 (2 bytes), we have two operations that we think about in terms of HBM bandwidth: the up-projection + SwiGLU as a single operation, then the down projection as a separate operation (we don't send the intermediate pre-gating value back to HBM, presumably because we can easily fuse this):

  • The first operation in terms of bandwidth is reading the input ($2M\dmodel$ bytes), reading the two matrices ($4\dexpert\dmodel$ bytes), and writing out the result to GMEM ($2M\dexpert$ bytes).
  • The second operation is reading that result ($2M\dexpert$ bytes), reading the down-proj matrix ($2\dexpert\dmodel$) and writing the result ($2M\dmodel$)

So the arithmetic intensity is: $$ \begin{aligned} \on{ArithInt} &= \frac{6M\dmodel\dexpert}{4M\dexpert + 6\dexpert\dmodel +4M\dmodel} \\ &= \frac{3}{\frac{2}{\dmodel} + \frac{3}{T\rho} + \frac{2}{\dexpert}} \end{aligned} $$ But since they've defined above $G = \dmodel / \dexpert$, the denominator is really 2 terms: $3/T\rho$, and $(2 + 2G)/\dmodel$. So as we increase the granularity $G$ OR make it more sparse (decrease $\rho$), we are decreasing the arithmetic intensity, pretty much proportionately.
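To convince myself the simplification is right, here is a small sketch that evaluates both the raw FLOPs/bytes ratio and the form written in terms of $G$ and $T\rho$. The sizes ($T$, $\rho$, $\dmodel$) are hypothetical values I picked so that everything divides evenly; they are not from the paper.

```python
def arith_intensity_raw(M, d_model, d_expert):
    """FLOPs / bytes, straight from the bookkeeping above (bf16, 2 bytes each)."""
    flops = 6 * M * d_model * d_expert
    bytes_moved = 4 * M * d_expert + 6 * d_expert * d_model + 4 * M * d_model
    return flops / bytes_moved


def arith_intensity_closed(T, rho, d_model, G):
    """Same quantity in terms of G = d_model / d_expert and M = T * rho."""
    return 3.0 / (3.0 / (T * rho) + (2.0 + 2.0 * G) / d_model)


# Hypothetical sizes, chosen only for illustration.
T, rho, d_model = 32768, 8 / 256, 4096
results = []
for G in (1, 2, 4, 8, 16):
    d_expert = d_model // G
    raw = arith_intensity_raw(T * rho, d_model, d_expert)
    closed = arith_intensity_closed(T, rho, d_model, G)
    results.append((G, raw, closed))
    print(f"G={G:2d}  raw={raw:8.1f}  closed={closed:8.1f}")
```

The two columns should agree, and the intensity should fall as $G$ grows with everything else held fixed.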

The above specific bookkeeping isn't really that important to understanding the intuition, though. When you are increasing the granularity, you're really making the matrices we multiply by a little smaller ($\dexpert$ is getting smaller), which is making the matmuls less "square". If you squint at it, the above is really the same as the arithmetic intensity of a regular $(M, K, N)$ matmul, and if you make $K$ smaller, you get lower arithmetic intensity (this is a different $K$ from the number of experts activated per token).

Another way to look at it: to keep FLOPs constant as you increase granularity, you need to decrease the size of each expert and increase the number of activated experts, but there are activations that are linear in the base model size $\dmodel$, which remains constant, and in $K$, the number of activated experts.

Naive MoE kernels

There are several MoE kernels available before this work, like MoMoE and ScatterMoE. Below I'll summarize how a very naive kernel works using SonicMoE's blog post (I think both of the above are more efficient than this naive version).

Naive implementation (using total tokens $T$):

  1. Gather the input $X$ of shape $(T, \dmodel)$ into an expanded form, repeating each input $K$ times so that it's $X_{\on{gathered}}:(TK, \dmodel)$
  2. Apply a grouped GEMM to $X_{\on{gathered}}$ along with the corresponding computed expert offsets (i.e. routing plan) and the up-project and gating weights (which are of course $(\dmodel, 2\dexpert)$ in size). The result is $H: (TK, 2\dexpert)$.
  3. Apply the SwiGLU operation to this to get the pre-down projection activations, $A: (TK, \dexpert)$.
  4. Apply the grouped GEMM with the down projection (and the expert offsets again) to get something of size $Y: (TK, \dmodel)$.
  5. Scatter and aggregate with the routing scores $S$ to get an output of size $(T, \dmodel)$

The real issue here is that as we get more granular, and more sparse, the activations of size $TK\dmodel$ get really big, as mentioned above. These have to fit in HBM and also transfer to compute and back. There's a similar issue with the backward pass.
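The five steps above can be sketched in NumPy as follows. This is my own minimal reference version, not code from SonicMoE or any of the kernels mentioned: the grouped GEMM is a plain Python loop over experts, and the function name, argument layout, and the choice of which half of $H$ is the gate are all assumptions made for illustration.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def naive_moe_forward(X, W13, W2, topk_idx, topk_score):
    """X: (T, d_model); W13: (E, d_model, 2*d_expert); W2: (E, d_expert, d_model);
    topk_idx, topk_score: (T, K) expert ids and routing scores per token."""
    T, d_model = X.shape
    K = topk_idx.shape[1]
    E = W13.shape[0]
    d_expert = W2.shape[1]

    # Step 1: gather. Sort the T*K (token, expert) pairs by expert id and
    # materialize the expanded input of shape (T*K, d_model).
    flat_expert = topk_idx.reshape(-1)
    order = np.argsort(flat_expert, kind="stable")
    token_of_row = order // K
    X_gathered = X[token_of_row]
    counts = np.bincount(flat_expert, minlength=E)
    offsets = np.concatenate([[0], np.cumsum(counts)])  # routing plan

    # Steps 2-4: "grouped GEMM" as a loop over contiguous expert slices.
    Y = np.empty((T * K, d_model), dtype=X.dtype)
    for e in range(E):
        lo, hi = offsets[e], offsets[e + 1]
        H = X_gathered[lo:hi] @ W13[e]               # (M_e, 2*d_expert)
        A = silu(H[:, :d_expert]) * H[:, d_expert:]  # SwiGLU (gate-first is assumed)
        Y[lo:hi] = A @ W2[e]                         # (M_e, d_model)

    # Step 5: scatter back to tokens, weighting by the routing scores.
    scores = topk_score.reshape(-1)[order]
    O = np.zeros((T, d_model), dtype=X.dtype)
    np.add.at(O, token_of_row, scores[:, None] * Y)
    return O


rng = np.random.default_rng(0)
T, K, E, d_model, d_expert = 16, 2, 4, 8, 4  # tiny hypothetical sizes
X = rng.standard_normal((T, d_model))
W13 = rng.standard_normal((E, d_model, 2 * d_expert))
W2 = rng.standard_normal((E, d_expert, d_model))
logits = rng.standard_normal((T, E))
topk_idx = np.argsort(-logits, axis=1)[:, :K]
topk_score = np.take_along_axis(logits, topk_idx, axis=1)
O = naive_moe_forward(X, W13, W2, topk_idx, topk_score)
print(O.shape)
```

Note that both `X_gathered` and `Y` are full $(TK, \dmodel)$ arrays here, which is exactly the materialization the rest of this post is about avoiding.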

SonicMoE

There's some smart fusing here, which is basically oriented around trying to avoid materializing anything that looks like the above shape. You can see the precise set of operations here.

So far so good - obviously a clever team that has found a smart way to avoid caching or materializing things that aren't necessary. The core idea is to avoid anything that is size $\mathcal{O}(TK\dmodel)$.

  • In the forward pass that's the $X$ after the gather operation (i.e. duplicating rows of $X$) but before the up-proj, and the $Y$ before the scatter-and-sum operation. SonicMoE fuses the up-proj and down-proj to avoid materializing the big $X$ and $Y$, and doesn't cache the results of the intermediate operations except for $H$, the output from the first matmul of the SwiGLU above.
  • It's more complicated for the backwards pass. We start with the grad of the output, and we need to compute $dX$, $dS$ (the routing scores), $dW_1$, and $dW_2$, the grads of the SwiGLU weights. If you store $S$ and $Y$ during the forward pass, computing $dS$ is just an inner product with $dO$. However, if we don't store it, the authors find a workaround: you can think of the application of $S$ as happening before the down-proj, and think of the result as $A'$, then you can avoid ever needing $Y$ at all.

Here's a more detailed explanation. Normally, to compute $dS$, we take the inner product of $dO$ and $Y$, which requires caching $Y$. However, let $A' = \on{diag}(S) \cdot A$, where it's after we route/gather $S$ to the right shape for $A$ ($TK \times \dexpert$). Then $dA = \on{diag}(S) \cdot dA'$ 2, and since $A'$ is multiplied directly by $W_2$ to get $O$, we can easily compute $dA'$ and $dA$ without storing anything besides the original inputs. The interesting part is this from the blog post and the paper:

$$\begin{align*} dS = \langle dO, Y\rangle = \langle dO, AW_2\rangle = \langle dOW_2^\intercal, A\rangle = \langle dA', A\rangle\end{align*}$$ Very cool. Then, since $A'$ is used to compute $O$ from $W_2$, you can compute $dW_2$ using that (once you've computed it in the backwards pass) and we can compute $dW_1$ since we store $H$. There's a detail here about exploiting L2 cache locality as well to make this faster.
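The identity is easy to check numerically. The sketch below uses a single expert and random data, with hypothetical sizes, and compares the version of $dS$ that needs the cached $Y$ against the version that only needs $A$ and $dA' = dO\,W_2^\intercal$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_expert, d_model = 64, 16, 32  # hypothetical sizes; N = rows routed to one expert

A = rng.standard_normal((N, d_expert))         # post-SwiGLU activations
W2 = rng.standard_normal((d_expert, d_model))  # down-projection of this expert
dO = rng.standard_normal((N, d_model))         # output grad, gathered to the same rows

# Version that needs Y = A @ W2 kept around (N x d_model).
Y = A @ W2
dS_with_Y = np.sum(dO * Y, axis=1)

# Version that only needs A (N x d_expert): dA' = dO @ W2^T.
dA_prime = dO @ W2.T
dS_without_Y = np.sum(dA_prime * A, axis=1)

max_abs_diff = np.max(np.abs(dS_with_Y - dS_without_Y))
print(max_abs_diff)
```

Both versions give one scalar per routed row; the second never touches anything of width $\dmodel$ other than $dO$ itself.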

In my head, the overall intuition for how SonicMoE avoids caching these big tensors is something like: this is kind of like avoiding materializing a super big matrix when you can instead do something like a torch.Tensor.scatter_. This isn't a particularly good intuition yet.

QuACK

This is where it gets really interesting. Generally speaking, when we write optimized kernels, we keep a certain target hardware in mind (e.g. H100 or B200). In CUDA, this usually means a specific compute architecture. There were significant changes to the hardware in between those two generations, and it took a pretty long time for kernels to fully support Blackwell (Triton had to rethink a lot of details with Blackwell, for example).

SonicMoE supports a LOT of different architectures - H100, B200, and also surprisingly the SM120 architecture in consumer Blackwell (and RTX 6000, a chip I work with). How do they do it? One of the trends in kernel design lately is that we are moving towards increasing levels of warp specialization in kernels, with producers and consumers. SonicMoE is written by taking advantage of a library, QuACK, that allows you to split the computation into 3 stages: prologue, mainloop, and epilogue, and customize the stages. It turns out that all of the SonicMoE kernels can be written by customizing the prologue and the epilogue, and allowing the mainloop to be a generic MMA (WGMMA or UMMA depending on hardware).

Why? The grouped GEMM kernels that are necessary for MoE mostly differ in how they scatter or gather the producer or epilogue steps. For example (my understanding), when fusing the gathered X up-projection, this is really a matmul with a producer that loads from different tokens (or different expert weights) based on the expert-token choices. This means the compute expensive part (and usually hardware-specific area) is specialized to the middle, but the flexibility (scatter/gather/activations/etc.) is in the producer and epilogue.

Scheduler

This is something I don't have any experience with. I haven't really understood this part except to understand that Blackwell is different from Hopper since tcgen05.mma is asynchronous and skips the registers, so a more complicated scheduler (which you can apparently customize) is necessary to keep track of which TMEM buffers are locked or ready.

General notes

One thing I'm always wondering when reading this is that it often feels like NVIDIA is writing the hardware and software abstractions with this stuff in mind. Some of the advancements in SonicMoE can be described as "how can we make a grouped varlen GEMM really just like a regular GEMM?" which means that the hardware primitives they need definitely exist because of how central matmuls are, so it's not crazy that everything exists, but it's still definitely very impressive.

  1. I'm just rewriting the ideas from the blog post as I parse through them; this isn't intended to explain anything better than the blog post. Feel free to point out mistakes in understanding.

  2. Excuse all the abuse of notation; this is just my reading notes but it's easier if you have all the shapes.

https://rachitsingh.com/reading/training-moes-efficiency/
About

Hi, I'm Rachit. I model energy systems (generation, transmission, and loads) using machine learning, and like to write in my spare time.

I previously spent time at Stripe, Jane Street, DE Shaw, HRT and in the Harvard NLP group. In college, I trained some LSTMs in Lua torch(7), and spent some time thinking about statistical inference.

Afterwards, I worked as a quant and programmer at DE Shaw.

I was part of the YC W22 batch, though I'm working on a new company now.

I mostly like to work on technical problems. A lot of my earlier work as a quant was on optimization of various kinds. You might be able to nerd-snipe me into optimizing your code (especially if it's open source). I also enjoy designing things and getting my hands dirty.

Please reach out if you're interested in talking about tools, machine learning, or writing. I'm generally happy to grab a coffee if you're in NYC, unless it's a busy time of year.

https://rachitsingh.com/about/
Applied research failure modes

Applied research teams slow down not because they lack good ideas but because of simpler issues: unclear metrics, confusing experiment design, poor iteration loops, and vague takeaways. I've been working on applied research1 in some form for my entire career so far. These are notes for early-career quants, ML researchers, and other applied computer scientists trying to improve their research process inside an organization.

Continuously estimate value

Before attempting an idea, estimate both the probability of it working and the upside if it does. Be wary of attempting high risk, low return work, even if it sounds cool. One important tool here is to upper bound the value of the project in some way if possible ("What if I had an oracle for X? What would happen? Is it even worth predicting X?").

Obviously you have to do some cool stuff too, but there's little worse than working on a hard project for a long time, succeeding, and realizing that no one cares.

Do not make every project high risk

There is one thing worse. Research is almost by definition high risk. If you only choose high risk, high return projects, and you continuously get burned (with a 5% hit rate you will fail on average 19 projects before you get a hit), you are likely to burn out. Some people are able to deal with this, but most people cannot.
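The 19 comes from treating projects as independent coin flips with success probability $p$, which is my simplifying assumption here: the number of failures before the first hit is geometric with mean $(1-p)/p$. A tiny sketch:

```python
p = 0.05  # hit rate per project, assumed independent across projects

# Mean number of failures before the first success (geometric distribution).
expected_failures = (1 - p) / p

# Chance of actually sitting through 19 failures in a row before any hit.
prob_19_straight_failures = (1 - p) ** 19

print(expected_failures, prob_19_straight_failures)
```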

Keep some low risk, low reward tasks around to keep your momentum up. I like to keep some optimization tasks in my queue for this reason, since it creates visible progress even when we're stuck as a team on the main research metrics.

Aggressively chunk your work

Research is full of ideas that feel interlocked. Figuring out how to split these ideas into separate chunks that can be parallelized or landed independently is hugely valuable. This is traditionally the area of a project manager, but often researchers need to be their own project manager.

Here's an example: let's say you think that incorporating some new training dataset will improve your model. Ingesting, cleaning, and organizing that dataset is a useful chunk that could be repurposed even if your model doesn't improve on the goal task!

Design convincing experiments

In applied research, you know your audience: it's your peers and internal decision makers. You can ask them questions before you do work. Examples:

  • Before running an experiment, check whether other people would draw the same conclusion from the hypothetical results. (if they won't be convinced, design a different experiment).
  • Register their hypotheses if possible. Treat all surprises to that prior as a useful takeaway.

Agree on metrics

It's very frustrating to realize that you're working on metric A, but your coworker only cares about metric B. Agree on what you're measuring, and how. It's easiest if everyone uses the same metric calculation code. Reality is rarely simple enough for one metric, but running an experiment and cherry-picking whichever metric improved is going to meet a lot of skepticism.

Build a reputation for calibrated and justified claims

It's very frustrating to realize that a core assumption you'd been making is based on a claim someone else made that is just inaccurate. This can seriously derail research projects. In applied research, precision is a form of respect for other people’s time.

Here's an example: if you say feature X and Y are correlated, and you draw that conclusion using some dataset D, it's useful to state that precisely - "I think X and Y are correlated when looking at dataset D". For example D might be data from the year 2024, which doesn't generalize to 2025. Maintain high skepticism of generalization. Narrow claims are not weaker, they are easier to trust.

Negative results can be useful, but false negatives are worse than false positives

Depending on what you work on, you might have a true hit rate on projects of less than, say, 10%. Save the negative outcomes so that you can build intuition, or if possible, communicate this intuition to others.

A false positive usually burns a bounded amount of time: someone tries to productionize the idea, the live or more realistic evaluation often catches it (in a good org), people are annoyed, and you might lose some credibility.

A false negative can be more expensive because it can remove an idea from the team’s search space. Paradoxically, the more influential or trusted you are, the longer it will be before someone attempts it again. This isn't necessarily a disaster, but keep this in mind when you run a "quick test" of an idea. If you don't give it a serious attempt, don't overclaim that the whole set of ideas doesn't work.

A common example is "I tried that feature before, it doesn't do anything". Be specific! What did you try?

Look at your data, not just summary statistics

Many research mistakes are hidden by aggregation, especially in machine learning.

Look at your features. Look at the response. Plot some things. Look at a specific example that is being passed through your network, and look at embeddings. Look at the loss.

Examples I've caught before: two hand-engineered features were computed by two different functions, but were actually identical. Another feature wasn't available after a certain point, and was just getting forward-filled for years.

Know if you are compute-bound or thinking-bound

All of my work so far is in machine learning or quant finance. Running experiments is automated, but can sometimes have a long runtime. Are you compute-bound (not enough GPUs?) or thinking/analysis bound (can't set up experiments quickly enough)? Tooling that automates experiment setup, babysitting, and analysis can move you from thinking-bound to compute-bound surprisingly quickly.

If you're compute bound, take note of when resources are free. Keep a notebook (or set of git branches) so that you can launch an experiment quickly. Even in resource constrained organizations there's a lot of wasted compute e.g. overnight or over the weekend. Good queue systems help a lot. Make your work easily preemptible and resumable. Make it robust to being restarted on another machine.

Decrease iteration time

If your experiments need to convert a lot of data from disk into an in-memory format to start, can you decrease that? Can you use faster disks or parallelize harder? Can you get good at generalizing from small changes (small dataset, small model) to the "full" set of changes?

There's a serious slowness introduced into an organization when a basic experiment takes longer than a few hours (or god forbid a day).

Standardize your plots and evals

Making charts via custom notebooks is slow, surprisingly error prone, and (most critically) forces your reviewer to do extra work understanding what you're showing. Often they'll request another evaluation or plot. Better to agree on these ahead of time and just run them every time on every experiment. I like papermill.

  1. (as opposed to academic research)

https://rachitsingh.com/applied-research/
Learning in low precision

$\newcommand{\on}[1]{\operatorname{#1}}$ I think I should be writing down some of the things I learn a little more frequently, so I'll try with a small note here.

I read this paper recently: Training LLMs with MXFP4. MXFP4 is a 4-bit OCP low precision format most well known for use in OpenAI's gpt-oss models. It uses an E2M1 layout (1 sign, 2 exponent, 1 mantissa bit) for the main data, with an extra shared scale value (E8M0) for every 32 elements (blockwise scaling). Note that this discussion is mostly about the weight format; the activation format would probably be MXFP8 or BF16. I think the main contribution of this paper is attempting to show that you can do it all in MXFP4, including the activations (see section 3). What they actually show is that it is possible to use MXFP4 in the backward passes of linear layers, while the forward pass is in FP8 or BF16.
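To make the format concrete, here is a rough sketch of quantizing one 32-element block. It reflects my understanding of the OCP MX convention (shared exponent chosen from the block's largest magnitude, elements snapped to the E2M1 grid), uses plain round-to-nearest rather than anything from the paper, and returns dequantized floats instead of packed 4-bit codes. The function name is made up.

```python
import numpy as np

# Magnitudes representable by E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def quantize_mxfp4_block(block):
    """Return (shared_exp, dequantized block) for a 1-D block of floats.
    Simplified: round-to-nearest, no special handling of NaN/Inf."""
    amax = np.max(np.abs(block))
    if amax == 0.0:
        return 0, np.zeros_like(block)
    # Largest E2M1 exponent is 2 (since 6 = 1.5 * 2**2), so the scaled
    # block maximum lands in [4, 8); anything above 6 gets clipped.
    shared_exp = int(np.floor(np.log2(amax))) - 2
    scale = 2.0 ** shared_exp
    scaled = block / scale
    mag = np.minimum(np.abs(scaled), 6.0)
    idx = np.argmin(np.abs(mag[:, None] - E2M1_GRID[None, :]), axis=1)
    q = np.sign(scaled) * E2M1_GRID[idx]
    return shared_exp, q * scale


rng = np.random.default_rng(0)
block = rng.standard_normal(32)
shared_exp, deq = quantize_mxfp4_block(block)
print(shared_exp, np.max(np.abs(deq - block)))
```

The clipping of scaled values between 6 and 8 is the source of bias that the 3/4 rescaling discussed later is meant to avoid.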

Note This is the third in a sequence of papers by Albert Tseng on LLM quantization; I didn't read the other two, though they were at ICML and NeurIPS.

There are two major advantages I see to using 4-bit weights at inference time:

  • memory: transferring weights from GMEM to SMEM/RMEM/TMEM is much faster if there's ~1/4 of it
  • compute: if the hardware supports compute kernels taking a 4-bit weight and multiplying by a MXFP8/BF16 activation, then you can get a speedup in raw FLOPS supported

The former is much easier, and the latter usually requires Blackwell (or similar).

This paper shows a way to do training with MXFP4 as well. Since training is usually so compute bound, this can be a huge benefit if you can deal with the instability. This paper introduces two methods to reduce the variance of the gradient estimates (sound familiar?!):

  1. Stochastic rounding, i.e. what my friends do to me when we go out for dinner
  2. Random Hadamard transforms, something I've heard of many times in a compressed sensing context but didn't remember well

Let's dive in.

Stochastic Rounding

The idea here is that if your low precision data type can only represent a small set of values (e.g. just {1.0, 2.0}) and you need to round a higher precision data type down, instead of just rounding 1.6 to 2.0, you can round it to 1.0 sometimes (with a probability $p$). The goal is for the expectation to be the same, i.e. $\mathbb{E}[\on{round}(x)] = x$. In this specific case we'd want $p = 2 - x$, or basically an inverse distance.
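Here is a minimal sketch of that two-value example in NumPy. It is a software illustration of the idea only (the function name and sample count are mine), not how hardware or the paper implements it.

```python
import numpy as np


def stochastic_round(x, lo, hi, rng):
    """Round each x in [lo, hi] to lo or hi so that E[round(x)] == x."""
    x = np.asarray(x, dtype=np.float64)
    p_lo = (hi - x) / (hi - lo)  # for lo=1, hi=2 this is p = 2 - x
    return np.where(rng.random(x.shape) < p_lo, lo, hi)


rng = np.random.default_rng(0)
samples = stochastic_round(np.full(200_000, 1.6), 1.0, 2.0, rng)
mean = samples.mean()
print(mean)  # should be close to 1.6, whereas round-to-nearest always gives 2.0
```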

When I first heard about this, I thought that this would be way too expensive to use, but it turns out there's a hardware thing called dithering which makes this almost no cost.

To make this work you also need to rescale the original values by 3/4 first (very interesting demonstration of how this works, but I think I understand: the original transform rescales to (-8, 8), but any values greater than 6 are clipped down, which introduces bias). We have to undo this later of course when applying the GEMM using this quantized weight.

Then we apply the stochastic rounding, which in expectation is the right number to multiply our activations by.

Random Hadamard Transform

This transform is $\on{RHT}: x \to HSx$, where $H$ is a Hadamard matrix, and $S$ is a random sign vector. Hadamard matrices are a specific series (not a class of matrices like upper triangular, but a specific set parametrized by a size $n$) where $H_n$ is recursively defined and orthogonal. From that, we conveniently have that $\on{RHT}(X)^\top \on{RHT}(Y) = X^\top Y$, which makes it easy to use in linear layers.
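A small check of the inner-product property, using the recursive (Sylvester) construction and the dense matrix rather than a fast transform. I normalize $H$ by $1/\sqrt{n}$ so that it is orthogonal; the sizes are arbitrary.

```python
import numpy as np


def hadamard(n):
    """Sylvester construction for n a power of two, scaled so H @ H.T == I."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)


def rht(X, signs):
    """Apply x -> H S x to every column of X (shape (n, m))."""
    return hadamard(X.shape[0]) @ (signs[:, None] * X)


rng = np.random.default_rng(0)
n = 64
X = rng.standard_normal((n, 5))
Y = rng.standard_normal((n, 7))
signs = rng.choice([-1.0, 1.0], size=n)

lhs = rht(X, signs).T @ rht(Y, signs)
rhs = X.T @ Y
print(np.max(np.abs(lhs - rhs)))
```

The same sign vector has to be used on both operands; that is what makes $S^\top H^\top H S$ collapse to the identity.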

There's a very nice theorem (what I like to call a Mitzenmacher-style theorem) that says that applying the RHT before quantization changes the variance dependence from $b$ to $\log b$:

Theorem Let $A$ and $B$ be two size-$b$ vectors, and let $\mathcal{Q}$ be the stochastic rounding quantization operation. Then, the variance of $\mathcal{Q}(A)^\intercal\mathcal{Q}(B)$ is $\mathcal{O}(b\Delta^4\lVert A\rVert_\infty\lVert B\rVert_\infty)$ and the variance of $\mathcal{Q}(HSA)^\intercal\mathcal{Q}(HSB)$ is with high probability $\mathcal{O}(\log(2b/\epsilon)\Delta^4\lVert A\rVert\lVert B\rVert)$

with probability parametrized by $\epsilon$. The precise form is in the paper. At face value the result is nice, because $\log(b)$ is obviously better than $b$. However, we know the max-norm is less than the L2 norms, and in fact the L2 norm grows approximately as $\sqrt{b}$... which should cancel out the effect?

I dug slightly into this, and it seems like a desirable property of LLM weights is that they're incoherent, i.e. their entries don't include a few large outliers, because such weights can be quantized without significant accuracy loss. Approximately, the definition is "no individual entry is too big compared to the Frobenius norm". My intuition at this point is that the Hadamard transform preserves the energy (it's an orthogonal transform), but it spreads out a very big outlier in one coordinate into all the coordinates, via this random sign. Then, since we share a scale amongst many values when quantizing, the quantization error goes way down. The specific bound used here is a Hoeffding-style concentration bound for signed sums, which gives the "sub-Gaussian" shape.
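Here's a toy illustration of that intuition, with made-up numbers: a vector that is tiny everywhere except one coordinate, before and after the transform. I use the ratio of the largest entry to the L2 norm as a stand-in for incoherence.

```python
import numpy as np


def hadamard(n):
    """Sylvester construction for n a power of two, scaled to be orthogonal."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)


def peak_to_norm(v):
    return np.max(np.abs(v)) / np.linalg.norm(v)


rng = np.random.default_rng(0)
n = 256
x = 0.01 * rng.standard_normal(n)
x[3] = 10.0  # a single large outlier

signs = rng.choice([-1.0, 1.0], size=n)
x_rot = hadamard(n) @ (signs * x)

before, after = peak_to_norm(x), peak_to_norm(x_rot)
print(before, after)
```

The norm is unchanged, but the outlier's energy ends up split across all $n$ coordinates, so a shared per-block scale no longer has to be sized for one huge entry.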

Figure 2 shows the relationship between these bounds (I didn't think about the experimental setup) and they appear to both be kind of linear after a while, though indeed the Hadamard-transformed version is lower variance.

But how do we compute it? It's not free to multiply your data by a matrix, even if it's random. There's a fast algorithm to apply the full RHT in $n \log n$ time, but I think the issue here is that we'd be mixing across the batch dimension, which is very expensive if you're doing any kind of data parallel work (different rows of the data live on different GPUs). They have a solution here that seems to involve some kind of blockwise RHT (the transform is a kind of "mix", and you can say "I want to only mix across the first X bits, then the next X bits, etc.", I think). I didn't have time to understand this.

Evaluation

This is a pre-training paper, which means every eval is extremely expensive.

They didn't have access to FP4 hardware (Blackwell) when running these experiments, which is definitely a bummer (this is from late 2024, I think). Instead they used a neat Microsoft library that emulates MX-format data types in PyTorch. They attempt to measure the real slowdowns / speedups, but it's not clear to me that the comparisons are valid to e.g. specific hardware.

Related results - more Hadamard rotations

I just saw this new paper by the Together AI lab (aka Tri Dao & co), focused on inference, which also uses a block diagonal Hadamard rotation:

Our central finding is that a simple design—token-wise INT4 quantization with block-diagonal Hadamard rotation—consistently achieves the best accuracy–efficiency trade-off.

This seems to be quite common in FP4 and INT4 training and inference, e.g. see this paper.

https://rachitsingh.com/reading/learning-in-low-precision/
Learning in low precision

$\newcommand{\on}[1]{\operatorname{#1}}$ I think I should be writing down some of the things I learn a little more frequently, so I'll try with a small note here.

I read this paper recently: Training LLMs with MXFP4. MXFP4 is a 4-bit OCP low precision format most well known for use in OpenAI's gpt-oss models. It uses an E2M1 layout (1 sign, 2 exponent, 1 mantissa bit) for the main data, with an extra shared scale value (E8M0) for every 32 elements (blockwise scaling). Note that this discussion is mostly about the weight format; the activation format would probably be MXFP8 or BF16. In practice the main contribution of this paper is attempting to show that you can do it all in MXFP4, I think, including the activations (see section 3). In practice they show it is possible to use MXFP4 in linear backward passes, while the forward pass is in FP8 or BF16.

Note This is the third in a sequence of papers by Albert Tseng on LLM quantization; I didn't read the other two, though they were at ICML and NeurIPS.

There are two major advantages I see to using 4-bit weights at inference time:

  • memory: transferring weights from GMEM to SMEM/RMEM/TMEM is much faster if there's ~1/4 of it
  • compute: if the hardware supports compute kernels taking a 4-bit weight and multiplying by a MXFP8/BF16 activation, then you can get a speedup in raw FLOPS supported

The former is much easier, and the latter usually requires Blackwell (or similar).

This paper shows a way to do training with MXFP4 as well. Since training is usually so compute bound, this can be a huge benefit if you can deal with the instability. This paper introduces two methods to reduce the variance of the gradient estimates (sound familiar?!):

  1. Stochastic rounding, i.e. what my friends do to me when we go out for dinner
  2. Random hadamard transforms, something I've heard of many times in a compressed sensing context but didn't remember well

Let's dive in.

Stochastic Rounding

The idea here is that if your low precision data type can only represent a small set of values (e.g. just {1.0, 2.0}) and you need to round a higher precision data type down, instead of just rounding 1.6 to 2.0, you can round it to 1.0 sometimes (with a probability $p$). The goal is for the expectation to be the same, i.e. $\mathbb{E}[\on{round}(x)] = x$. In this specific case we'd want $p = 2 - x$, or basically an inverse distance.

When I first heard about this, I thought that this would be way too expensive to use, but it turns out there's a hardware thing called dithering which makes this almost no cost.

To make this work you also need to rescale the original values by 3/4 first (very interesting demonstration of how this works, but I think I understand: the original transform rescales to (-8, 8), but any values greater than 6 are clipped down, which introduces bias). We have to undo this later of course when applying the GEMM using this quantized weight.

Then we apply the stochastic rounding, which in expectation is the right number to multiply our activations by.

Random Hadamard Transform

This transform is $\on{RHT}: x \to HSx$, where $H$ is a Hadamard matrix, and $S$ is a random sign vector. Hadamard matrices are a specific series (not a class of matrices like upper triangular, but a specific set parametrized by a size $n$) where $H_n$ is recursively defined and orthogonal. From that, we conveniently have that $\on{RHT}(X)^\top \on{RHT}(Y) = X^\top Y$, which makes it easy to use in linear layers.

There's a very nice theorem (what I like to call a Mitzenmacher-style theorem) that says that applying the RHT before quantization changes the variance dependence from $b$ to $\log b$:

Theorem Let $A$ and $B$ be two size-$b$ vectors, and let $\mathcal{Q}$ be the stochastic rounding quantization operation. Then, the variance of $\mathcal{Q}(A)^\intercal\mathcal{Q}(B)$ is $\mathcal{O}(b\Delta^4\lVert A\rVert_\infty\lVert B\rVert_\infty)$ and the variance of $\mathcal{Q}(HSA)^\intercal\mathcal{Q}(HSB)$ is with high probability $\mathcal{O}(\log(2b/\epsilon)\Delta^4\lVert A\rVert\lVert B\rVert)$

with probability parametrized by $\epsilon$. The precise form is in the paper. At face value the result is nice, because $\log(b)$ is obviously better than $b$. However, we know the max-norm is less than the L2 norms, and in fact the L2 norm grows approximately as $\sqrt{b}$... which should cancel out the effect?

I dug slightly into this, and it seems like a desirable property of LLM weights is that they're incoherent, i.e. they are entries without a few large outliers, because they can be quantized without significant accuracy loss. Approximately this definition is "no individual entry is too big compared to the Frobenius norm". My intuition at this point is that the Hadamard transform preserves the energy (it's an orthogonal transform), but it spreads out a very big outlier in one coordinate into all the coordinates, via this random sign. Then, since we share a scale amongst many values when quantizing, the quantization error goes way down. The specific bound used here is a Hoeffding-style concentration bound for signed sums, which gives the "sub-Gaussian" shape.

Figure 2 shows the relationship between these bounds (I didn't think about the experimental setup) and they appear to both be kind of linear after a while, though indeed the Hadamard-transformed version is lower variance.

But how do we compute it? It's not free to multiply your data by a matrix, even if it's random. There's a fast algorithm to apply the full RHT in $n \log n$ time, but I think the issue here is that we'd be mixing across the batch dimension, which is very expensive if you're doing any kind of data parallel work (different rows of the data live on different GPUs). They have a solution here that seems to involve some kind of blockwise RHT (the transform is a kind of "mix", and you can say "I want to only mix across the first X bits, then the next X bits, etc.", I think). I didn't have time to understand this.
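Here is my rough mental model of a blockwise RHT, as a sketch only: I'm assuming "blockwise" means applying an independent fast Walsh-Hadamard transform to each contiguous block of a vector, which may not match the paper's exact scheme. The names `fwht` and `blockwise_rht` are hypothetical.

```python
import math
import random

def fwht(block: list[float]) -> list[float]:
    """Fast Walsh-Hadamard transform in O(n log n), scaled to be orthonormal.
    len(block) must be a power of two."""
    a = list(block)
    n = len(a)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                u, v = a[i], a[i + h]
                a[i], a[i + h] = u + v, u - v
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in a]

def blockwise_rht(row: list[float], signs: list[float], block_size: int) -> list[float]:
    """Apply random signs, then transform each block independently,
    so entries never mix across block boundaries."""
    out: list[float] = []
    for start in range(0, len(row), block_size):
        blk = [s * v for s, v in zip(signs[start:start + block_size],
                                     row[start:start + block_size])]
        out.extend(fwht(blk))
    return out

rng = random.Random(0)
row = [0.0] * 64
row[3] = 8.0  # a single large outlier in the first block
signs = [rng.choice([-1.0, 1.0]) for _ in range(64)]
out = blockwise_rht(row, signs, block_size=32)

print(max(abs(v) for v in row), max(abs(v) for v in out))
```

The outlier's energy is preserved but smeared evenly over the 32 entries of its own block (each has magnitude $8/\sqrt{32}$), and the second block is untouched, which is the "spread the outlier, but only locally" intuition.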

Evaluation

This is a pre-training paper, which means every eval is extremely expensive.

They didn't have access to FP4 hardware (Blackwell) when running these experiments, which is definitely a bummer (this is from late 2024, I think). Instead they used a neat Microsoft library that emulates MX-format data types in PyTorch. They attempt to measure the real slowdowns / speedups, but it's not clear to me that the comparisons are valid to e.g. specific hardware.

Related results - more Hadamard rotations

I just saw this new paper by the Together AI lab (aka Tri Dao & co), focused on inference, which also uses a block diagonal Hadamard rotation:

Our central finding is that a simple design—token-wise INT4 quantization with block-diagonal Hadamard rotation—consistently achieves the best accuracy–efficiency trade-off.

This seems to be quite common in FP4 and INT4 training and inference, e.g. see this paper.

https://rachitsingh.com/learning-in-low-precision/
Claude Code to (my) exhaustion: a DP worklog

Recently at work I had a side quest on the following task: choose a binning scheme for some 1-dimensional data that has an annoying distribution. Quantile binning wasn't great (at the sparse tails it produces very wide bins, and in the center very tiny ones), and equal width was also out.

Basically I wanted to be able to specify a number of bins (n_bins) and get some "reasonable" bin edges out.

def compute_bin_edges(data: np.ndarray, n_bins: int) -> np.ndarray:
    ...

Well, one nice definition of "reasonable" is the Jenks natural breaks or Fisher's natural breaks1: pick the bins so that the within-bin variance is minimized. This method seems to be pretty widely used in mapping and visualization (especially GIS) libraries to try to divide datasets (e.g. color parts of a map). It was even featured in a ThePrimeagen video.

The core algorithm would make a nice dynamic programming homework problem:

Problem Given $N$ data points $x_1, x_2, \ldots, x_N$, find an assignment of the points to $B$ bins $\{\mathcal{B}_1, \mathcal{B}_2, \ldots, \mathcal{B}_B\}$ that minimizes:

$$\sum_{i = 1}^B \sum_{x \in \mathcal{B}_i} \left(x - \mu_i\right)^2, \quad \mu_i := \overline{\mathcal{B}_i}$$

Here's part of the solution (the DP definitions), hidden in case you want to try to think it through for yourself:

Dynamic programming setup Sort the data. Let $\on{D}[i, m]$ be the optimal cost of clustering $x_1, x_2, \ldots, x_m$ into $i$ clusters. Let $\on{T}[i, m]$ be the first point in the rightmost cluster when assigning $x_1, \ldots, x_m$ to $i$ clusters. Can you figure out how to use $\on{D}[i-1]$ and $\on{T}[i-1]$ to calculate $\on{D}[i]$ and $\on{T}[i]$?

There are 3 clean ways to do this2:

  1. A simple quadratic ($\mathcal{O}(k \cdot n^2)$) version that just takes advantage of optimal substructure (i.e. classic DP, no optimizations) (Jenks, 1977).
  2. A slightly more complicated log-linear variant ($\mathcal{O}(k \cdot n \log n)$) that also takes advantage of something called the concave Monge property (but is just a way to compute the DP matrices faster) (Wu, 1989) 3.
  3. A shiny fast linear variant ($\mathcal{O}(k \cdot n)$) that uses a more complicated algorithm (SMAWK) to find the minima of the DP matrices (the goal) while computing only what's necessary (Wu, 1991). Note that this is linear iff the input is sorted (for the rest of this we'll assume the data is pre-sorted). (Song, 2020) also implemented an in-place search space reduction (which to be honest was what I thought SMAWK was already doing).
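For concreteness, here is a minimal sketch of the first (quadratic) variant. It spoils the DP exercise above, so skip it if you want to work that out yourself. This is my own plain-Python illustration, not the code from the repository: the function name `optimal_breaks_quadratic` and the choice to return bin start indices are assumptions, and prefix sums are used so each within-bin cost is O(1).

```python
def optimal_breaks_quadratic(data, n_bins):
    """Return (cost, starts): the minimal within-bin sum of squares and the
    index (into the sorted data) where each bin starts. Assumes data is non-empty."""
    xs = sorted(data)
    n = len(xs)
    k = min(n_bins, n)  # can't have more non-empty bins than points

    # Prefix sums of x and x^2 so the cost of any contiguous segment is O(1).
    s1 = [0.0] * (n + 1)
    s2 = [0.0] * (n + 1)
    for i, x in enumerate(xs):
        s1[i + 1] = s1[i] + x
        s2[i + 1] = s2[i] + x * x

    def sse(j, m):
        """Sum of squared deviations of xs[j:m] from its mean."""
        total = s1[m] - s1[j]
        return max(0.0, (s2[m] - s2[j]) - total * total / (m - j))

    INF = float("inf")
    # D[i][m]: best cost of splitting xs[:m] into i bins.
    # T[i][m]: start index of the rightmost bin in that optimal split.
    D = [[INF] * (n + 1) for _ in range(k + 1)]
    T = [[0] * (n + 1) for _ in range(k + 1)]
    D[0][0] = 0.0
    for i in range(1, k + 1):
        for m in range(i, n + 1):
            for j in range(i - 1, m):  # rightmost bin is xs[j:m]
                c = D[i - 1][j] + sse(j, m)
                if c < D[i][m]:
                    D[i][m], T[i][m] = c, j

    # Backtrack through T to recover where each bin starts.
    starts, m = [], n
    for i in range(k, 0, -1):
        m = T[i][m]
        starts.append(m)
    starts.reverse()
    return D[k][n], starts

cost, starts = optimal_breaks_quadratic([1, 2, 3, 10, 11, 12, 50, 51], 3)
print(cost, starts)  # 4.5 [0, 3, 6]
```

The two faster variants compute the same `D` and `T` tables; they only differ in how the inner minimum over `j` is found.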

To be clear there are plenty of packages to do this in Python already with good enough performance. There's also a very feature rich reference implementation by some of the authors of the 2020 paper in R/C++. But I was interested in learning about the linear algorithm4 and also interested in some extensions (adding min/max width constraints, $k$-medians, etc.), which I figured would be relatively easy once I had the core algorithm implemented.

So I figured, great task for Claude Code, right? This is:

  1. A clear problem description with 3 variants that are pure functions all with the same API
  2. At least the first two solutions are probably well represented in LLM training data
  3. It's easy to generate more test data (random()!) and there is a unique correct solution for every input

Giving Claude context

Since there was a high quality reference implementation for these algorithms, I cloned the repository, figuring that giving Claude access to the code would be useful (I put this in a reference/ folder, and put directions in CLAUDE.md). I also downloaded several papers and put them in .tex and .pdf form into the repository: (Grønlund, 2017), (Wu, 1991), and (Aggarwal, 1987).

Of course as usual I created a CLAUDE.md with plenty of detail and a TODO.md with an organized plan for implementation, which looked something like this:

  1. Make really good test cases (i.e. test for degenerate distributions, really sparse in tails, really sparse in middle, etc.)
  2. Implement the QUADRATIC algorithm, and for small test cases check it for correctness against a brute force version
  3. Implement the LOGLINEAR algorithm and verify that it gives exactly the same result as the QUADRATIC
  4. Implement the LINEAR algorithm and verify that it gives exactly the same results as the other two algorithms
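The brute-force check in step 2 can be as simple as trying every placement of cut points. A minimal sketch (the name `brute_force_breaks` is hypothetical, and this is only feasible for tiny inputs since it enumerates all $\binom{n-1}{k-1}$ splits):

```python
from itertools import combinations

def sse(values):
    """Sum of squared deviations from the mean."""
    mu = sum(values) / len(values)
    return sum((v - mu) ** 2 for v in values)

def brute_force_breaks(data, n_bins):
    """Try every way to cut the sorted data into contiguous non-empty bins.
    Returns (cost, starts) with starts indexing into the sorted data."""
    xs = sorted(data)
    n = len(xs)
    k = min(n_bins, n)
    best_cost, best_starts = float("inf"), None
    for cuts in combinations(range(1, n), k - 1):
        starts = (0,) + cuts
        ends = cuts + (n,)
        cost = sum(sse(xs[a:b]) for a, b in zip(starts, ends))
        if cost < best_cost:
            best_cost, best_starts = cost, starts
    return best_cost, list(best_starts)

bf_cost, bf_starts = brute_force_breaks([1, 2, 3, 10, 11, 12, 50, 51], 3)
print(bf_cost, bf_starts)
```

Only contiguous splits of the sorted data need to be enumerated, because an optimal 1D clustering never interleaves bins.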

Of course, I had implicit ideas about how this should be done, but I didn't manage to write this down in the TODO.md or background information. I mentioned that it would be great to implement this in Python + Numba to keep the final implementation readable / easy to change (Numba is fantastic these days). Unfortunately, in the CLAUDE.md I also described how I wanted to eventually add min/max_width constraints.

What went well

Claude was able to relatively quickly deliver working code for the core QUADRATIC and LOGLINEAR algorithms.

What was not so good

  1. Claude read the reference code and saw a structure that switched based on user input between (QUADRATIC, LOGLINEAR, LINEAR) algorithms and also between $k$-means and $k$-medians, and it decided to mimic that structure (let's call it general_ckmeans). This overcomplicated the code early on (a lot of jumping between functions) and made it hard to read. I told it not to do that and asked for the 3 functions that I wanted implemented as separate interfaces.
  2. Claude also saw code in the reference implementation that tried to estimate what the optimal n_bins was and implemented a LOT of new code related to that; this also overcomplicated the data flow a lot.
  3. It generated test cases before doing the above (as I asked), but since it left general_ckmeans around, the tests were wired against the old implementations, so when it ran the tests, they passed (that was a correct implementation!). I realized that if you asked for n_bins that were more than the length of the data it returned essentially an error, so I explained what I wanted instead (adjust n_bins down / return fewer bins) and it added both a new test (as I explicitly asked) and the code changes, and it failed (still old implementation).
  4. Claude decided to add min/max_width arguments to the implementation while writing it (seeing my desire to do this later), but since it didn't see code to implement it in the reference implementation, it just ignored those arguments. It apparently added some test cases for those arguments, but then when those failed, it read its code, realized it was ignoring those arguments, and decided to silently make the tests pass "for now". Given that I wanted to do test-driven development, this was super frustrating.
  5. I told Claude to never do that again (failing tests is ok unless the current goal is to solve that test specifically), and it made those tests actually run. However, my trust in Claude was shaken at this point so I checked the testcases: there were still some passing constraint-checking testcases. It turned out the min_width and max_width passed in those cases were just too lenient, so the optimal bins (without constraints) also passed.
  6. I wanted a bit more confidence that these bin assignments were optimal, so I asked it to do an iterative refinement check, by basically wiggling a random bin edge until it changed an assignment, then checking if that assignment was better - essentially, are we at a local minimum? Claude added this test, but it took a while to explain what I wanted, and to verify that Claude did this correctly.
  7. Claude decided around this point that it wanted to implement the min/max width constraints (I had added it to the TODO list at the end by this point) and it decided to do this by just wiggling the bin widths after optimization until they met the min/max constraints. It only did this if min/max bin width were passed, and since I didn't have a way to compute the optimal answer in that case, I didn't pass in min/max bin width in the first set of correctness tests. So it was giving incorrect results (a greedy approach like the above doesn't work), and I didn't know.
  8. Eventually I noticed the constraint tests were re-activated and tried to investigate; I also expanded the local minima tests to include min/max width constraints, which promptly failed. Investigating led me to tell Claude to stop implementing min/max width constraints and delete the incorrect code.
  9. At this point we're trying to implement the LINEAR variant. It tries for a second, decides it's too hard, and decides to implement an approximation (this decision happens in thinking but wasn't communicated to me). This was very frustrating and I ask it what context it needs to implement this correctly, and get those references (some of the papers, it needs some human description of what part of the reference codebase implements this, etc). Somewhere around here I reset the session and context.
  10. It implements the LINEAR variant with a lot of guided help from me, e.g. telling it to check the computed cost/backtracking matrices against the LOGLINEAR variant (which we have to refactor to extract, etc.). I don't pay attention to details besides "is this implementing it the right way" because, frankly, the LINEAR algorithm is complex. Maybe the AI can do it for me. The only exception is when it started segfaulting, I told it to turn the recursive algorithm into an iterative one.
  11. The local minima test cases are failing for the linear variant, which Claude eventually concludes is a "known issue" (what?). I ask it about this and it says "there's probably a subtle bug in these testcases". I ask it to check if the new bin assignment that is being generated from the tweaked bin edges is even different... it isn't. Since the cost function went down, that means the cost function is incorrect - I realize at this point that this whole set of testcases doesn't really work.
  12. After some detailed directing, it makes those test cases correct again, and eventually makes the LINEAR testcases pass. Wow!
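The local-minimum check from item 6 can be sketched as follows. This is my reconstruction of the idea, not the test from the repository: instead of wiggling a floating-point bin edge until an assignment changes, it moves each boundary by exactly one data point, which guarantees the assignment really is different (the failure mode in item 11). The names `total_cost` and `is_local_minimum` are hypothetical.

```python
def total_cost(xs, starts):
    """Within-bin sum of squares for sorted xs, with bins starting at `starts`."""
    ends = list(starts[1:]) + [len(xs)]
    cost = 0.0
    for a, b in zip(starts, ends):
        seg = xs[a:b]
        mu = sum(seg) / len(seg)
        cost += sum((v - mu) ** 2 for v in seg)
    return cost

def is_local_minimum(xs, starts, tol=1e-12):
    """True if moving any single boundary by one point never lowers the cost."""
    base = total_cost(xs, starts)
    for idx in range(1, len(starts)):  # starts[0] is always 0
        for delta in (-1, 1):
            moved = list(starts)
            moved[idx] += delta
            lo = moved[idx - 1]
            hi = moved[idx + 1] if idx + 1 < len(moved) else len(xs)
            if not (lo < moved[idx] < hi):
                continue  # this move would create an empty bin
            if total_cost(xs, moved) < base - tol:
                return False
    return True

xs = [1, 2, 3, 10, 11, 12, 50, 51]
print(is_local_minimum(xs, [0, 3, 6]))  # True
print(is_local_minimum(xs, [0, 2, 6]))  # False
```

Passing this is necessary but not sufficient for optimality; it just catches implementations that are obviously off.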

My overall feeling after this was basically that while you can use Claude or Copilot or Cursor to generate the tests the first time, it's probably best to think hard about it yourself afterwards, verify that it's the behavior you want, organize them by which TODOs will solve them, and then don't let Claude touch the tests. It is just so willing to silence failing tests "for now" and it really leads to confusion. I also think the bias towards "must solve testcases" is so strong that it will kind of jump ahead and solve them even if that's not the next task on the TODO.

Where I gave up

Right after the above series of steps I asked it to profile the runtimes of each algorithm. It produces some emoji-heavy slop that is impossible to read, so I manually clean up the profiling script. It was profiling including the sort, so I refactor the code to move that out of the core function that's profiled.

At this point I realize the runtime of the LINEAR variant is actually quadratic in the input (in fact, it's slower than the QUADRATIC). That's about where I gave up on Claude Code to get me farther.

Fixing the slop

Since Claude had implemented the LINEAR variant I kind of stubbornly stuck to that code for a while and tried to see if I could fix it, but since I didn't know the algorithm at all, I had no idea what was going on. It turns out that correctly implementing it in linear time is subtle, and I had just thrown some kind of rabid approximate LLM agent at it.

So I reached for tools from a more elegant age: I opened the 1987 paper (scanned pdf!). Even there I had "better" tools though than just reading it top to bottom - I asked o3 about the paper, what a totally monotone matrix was, intuition for the lemmas, etc.

Around this point I started to think about the tradeoffs of using an LLM and using my own brain. To be honest the time I spent asking o3 about this paper was nontrivial, and the paper is not that long. Maybe I'm saying this on the other side after figuring it out, but it felt like every time I was confused I would ask o3 some question, read the response, think about it, but because the "escape button" was so convenient, I never thought very hard. Eventually I stopped using o3 and actually thought about it, and it started to make sense.

Then I wrote a (hopefully) correct implementation in pure Python. This one scaled linearly in the number of accesses to the monotone matrix (a proxy for the runtime), so I spent a little bit of time optimizing it to remove unnecessary rows, wrap with nb.njit, etc. And voila! A fast "human crafted" implementation of the linear SMAWK algorithm.
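Counting matrix accesses is easy to set up and is a much less noisy signal than wall-clock time. Here's a minimal sketch of the idea; note that the row-minima routine below is the simpler divide-and-conquer search (the log-linear flavor), not SMAWK, and all names (`CountingMatrix`, `row_minima_naive`, `row_minima_monotone`) and the toy matrix are my own illustration.

```python
class CountingMatrix:
    """Wraps an implicit matrix f(i, j) and counts how many entries get evaluated."""
    def __init__(self, f, n_rows, n_cols):
        self.f, self.n_rows, self.n_cols = f, n_rows, n_cols
        self.accesses = 0

    def __call__(self, i, j):
        self.accesses += 1
        return self.f(i, j)

def row_minima_naive(M):
    """Leftmost argmin of every row, looking at every entry."""
    return [min(range(M.n_cols), key=lambda j: M(i, j)) for i in range(M.n_rows)]

def row_minima_monotone(M):
    """Leftmost argmin of every row, assuming the argmin column is
    non-decreasing from one row to the next (a monotone matrix)."""
    result = [0] * M.n_rows

    def solve(r_lo, r_hi, c_lo, c_hi):
        if r_lo > r_hi:
            return
        mid = (r_lo + r_hi) // 2
        best, best_val = c_lo, M(mid, c_lo)
        for c in range(c_lo + 1, c_hi + 1):
            v = M(mid, c)
            if v < best_val:
                best, best_val = c, v
        result[mid] = best
        solve(r_lo, mid - 1, c_lo, best)   # rows above: argmin is at or left of best
        solve(mid + 1, r_hi, best, c_hi)   # rows below: argmin is at or right of best

    solve(0, M.n_rows - 1, 0, M.n_cols - 1)
    return result

n = 64
f = lambda i, j: (j - i) ** 2  # toy monotone matrix: row i is minimized at column i
naive, fast = CountingMatrix(f, n, n), CountingMatrix(f, n, n)
naive_result = row_minima_naive(naive)
fast_result = row_minima_monotone(fast)
print(naive.accesses, fast.accesses)
```

Plotting `accesses` against `n` for a few sizes makes it obvious whether an implementation is actually linear, log-linear, or secretly quadratic.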

After this I realized that there were many inconsistencies across the different algorithms (e.g. handling early exits, etc.) that made reading the code harder, so I cleaned that up. Claude could have probably done this, but I was tired of prompting by this point.

Was it worth it?

Vibe-coding right now is valuable to me because it helps me understand the limits and capabilities of coding LLMs (agents and tab-style assistants). So the exploration was worth it for that alone.

As a development assistant, the thing that strikes me is that I didn't really fully understand the algorithm until I stopped using LLMs at all, which is probably no surprise to anyone. And I think understanding was a prerequisite to finishing; trying to avoid that by telling Claude to test X or Y or think harder was not productive. I was hoping that Claude was smarter than me and could present the idea in a nicer way than the paper via the code, but that wasn't really true.

Incidentally I built a lot of tools (careful reusable prompts) for reading/summarizing papers and guiding Claude towards test driven development, so that work might give returns into the future.

I was reading antirez's blog post and saw this:

[...] when left alone with nontrivial goals they tend to produce fragile code bases that are larger than needed, complex, full of local minima choices, suboptimal in many ways. Moreover they just fail completely when the task at hand is more complex than a given level.

Well, that's my experience. Even for very pure implementation goals, they're just very sloppy. The emojis in particular in my entire codebase drove me nuts (I eventually made a TODO to remove all emojis, which you can )

What's left with binning

shows a way to do the binning with the $k$-medians objective, and that's probably pretty easy to adapt (in practice this is a few lines change). I also want to implement min/max bin widths since practically speaking I need those for work.

A fun plot

As any student of computers knows, the "best" algorithm by big-$\mathcal{O}$ runtime is rarely (never?) the most useful, because of dumb "reality" like cache locality, aligned memory access, register pressure, etc.

In practice the LOGLINEAR algorithm is significantly better than the LINEAR algorithm, and you can barely feel the $\log n$ factor compared to the constant:

Check it out on GitHub!

Challenge

Someone out there is a better LLM whisperer than me. Do you have a prompt and environment for Claude that one-shots the 1D optimal breaks problem?

References
[1] Jenks, George F., "Optimal data classification for choropleth maps", Department of Geography, University of Kansas Occasional Paper, 1977
[2] Aggarwal, Alok and Klawe, Maria M. and Moran, Shlomo and Shor, Peter and Wilber, Robert, "Geometric applications of a matrix-searching algorithm", Algorithmica, 1987
[3] Wu, Xiaolin and Rokne, John, "An O(KN lg N) algorithm for optimum K-level quantization on histograms of N points", in Proceedings of the 17th conference on ACM Annual Computer Science Conference, 1989, pp. 339-343
[4] Wu, Xiaolin, "Optimal quantization by matrix searching", Journal of Algorithms, 1991
[5] Grønlund, Allan and Larsen, Kasper Green and Mathiasen, Alexander and Nielsen, Jesper Sindahl and Schneider, Stefan and Song, Mingzhou, "Fast exact k-means, k-medians and Bregman divergence clustering in 1D", arXiv preprint arXiv:1701.07204, 2017
[6] Song, Mingzhou and Zhong, Hua, "Efficient weighted univariate clustering maps outstanding dysregulated genomic zones in human cancers", Bioinformatics, 2020

  1. It's a little unclear to me whether Fisher's natural breaks refers to the result (i.e. the optimum) or specifically the $\mathcal{O}(k \cdot n \log n)$ algorithm that produces that result (maybe it's the $\mathcal{O}(k \cdot n^2)$ algorithm?).

  2. The best reference I found was this relatively recent paper that (a) writes a nice exposition on the problem for $k$-means, and then derives how to do it for $k$-medians and actually all Bregman divergences: (Grønlund, 2017)

  3. It's actually quite difficult to find the proper citation for these algorithms. Maarten Hilferink of GeoDMS (Object Vision BV) is the author of this page, which shows a proof of the log-linear method. However, it seems like Wu in 1989 and 1991 essentially solved this problem using Aggarwal's method (SMAWK), so I think that's probably the right citation. Please correct me if that's wrong.

  4. I found it pretty surprising that there was a linear algorithm! Usually linear algorithms are pretty simple, i.e. sweep over the data or some transformation of the data (e.g. toposorted DAG in reverse), but this seemed more complex.

https://rachitsingh.com/claude-code/
New Mac setup script

One of the best things about LLMs is that the cost of writing bash or zsh has gone to 0. Here's a way to get started on a new Mac by setting up uv then using it to drive a bunch of macOS automation:

#!/usr/bin/env zsh
set -euo pipefail

# 0. Install uv if missing
if ! command -v uv >/dev/null; then
  echo "→ Installing uv…"
  curl -LsSf https://astral.sh/uv/install.sh | sh
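  # Newer uv installers put the binary in ~/.local/bin; older ones used ~/.cargo/bin
  export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"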
fi

which can then just run this Python script:

#!/usr/bin/env python
from __future__ import annotations
import shutil, subprocess, sys, os, platform, shlex
from pathlib import Path
import inspect

def run(cmd: list[str] | str, sudo: bool = False) -> None:
    if sudo:
        # Split string commands into argv so sudo receives a real argument list
        # (a single-element list would be treated as one long program name).
        cmd = ["sudo", "-E"] + (cmd if isinstance(cmd, list) else shlex.split(cmd))
    print("·", " ".join(cmd) if isinstance(cmd, list) else cmd)
    subprocess.run(cmd, check=True, shell=isinstance(cmd, str))

def ensure_in_path(brew_bin: str) -> None:
    incantation = f'eval "$({brew_bin} shellenv)"'
    shellrc = Path.home() / ".zshrc"
    text = shellrc.read_text() if shellrc.exists() else ""
    if incantation not in text.splitlines():
        shellrc.write_text(text + f"\n# Homebrew\n{incantation}\n")

def install_homebrew() -> None:
    if shutil.which("brew"):
        return
    run(
        '/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"',
        sudo=False,
    )
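    # apple-silicon vs intel
    brew_bin = "/opt/homebrew/bin/brew" if Path("/opt/homebrew/bin").exists() else "/usr/local/bin/brew"
    ensure_in_path(brew_bin)
    # ~/.zshrc only affects new shells, so also expose brew to the rest of this run
    os.environ["PATH"] = f"{Path(brew_bin).parent}{os.pathsep}{os.environ.get('PATH', '')}"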

def install_mas() -> None:
    run(["brew", "install", "mas"])

def is_mas_installed(app_id: str) -> bool:
    out = subprocess.run(["mas", "list"], capture_output=True, text=True, check=True).stdout
    return any(line.split()[0] == app_id for line in out.splitlines())

def mas_apps() -> None:
    app_id = "1475387142"          # Tailscale
    if is_mas_installed(app_id):
        return                     # already installed – skip prompt & install
    input("Sign in to the Mac App Store now, then press <Enter>...")
    run(["mas", "install", app_id])

def brew_casks() -> None:
    casks = [
        "ghostty", "alt-tab", "zed", "visual-studio-code",
        "slack", "1password", "notion-calendar", "obsidian",
        "linear-linear", "raycast",
    ]
    run(["brew", "install", "--quiet", "--cask"] + casks)

def brew_cli() -> None:
    run(["brew", "install", "starship", "tmux"])

def keyboard_tweaks() -> None:
    cmds = [
        'defaults write NSGlobalDomain NSAutomaticSpellingCorrectionEnabled -bool false',
        'defaults write NSGlobalDomain ApplePressAndHoldEnabled -bool false',
        'defaults write NSGlobalDomain InitialKeyRepeat -int 10',
        'defaults write NSGlobalDomain KeyRepeat -int 1',
    ]
    for c in cmds:
        run(c, sudo=True)

def setup_starship() -> None:
    # need to add this command to ~/.zshrc idempotently if it's not already there
    starship_init_command = 'eval "$(starship init zsh)"'
    zshrc_path = Path.home() / ".zshrc"
    if zshrc_path.exists():
        with zshrc_path.open("r") as f:
            content = f.read()
        if starship_init_command not in content:
            with zshrc_path.open("a") as f:
                f.write(f"\n{starship_init_command}\n")
    else:
        with zshrc_path.open("w") as f:
            f.write(f"# Zsh configuration file\n{starship_init_command}\n")
    command = "starship preset no-nerd-font -o ~/.config/starship.toml"
    if not Path("~/.config/starship.toml").expanduser().exists():
        # make sure the ~/.config folder exists
        config_dir = Path.home() / ".config"
        config_dir.mkdir(parents=True, exist_ok=True)
        run(command, sudo=False)

def customize_zshrc() -> None:
    block = inspect.cleandoc("""
        export HISTFILE="$HOME/.zsh_history"
        export HISTSIZE=1000000000
        export SAVEHIST=1000000000
        setopt EXTENDED_HISTORY
    """) + "\n"
    zshrc = Path.home() / ".zshrc"
    content = zshrc.read_text() if zshrc.exists() else ""
    if "EXTENDED_HISTORY" not in content:
        zshrc.write_text(content + block)


def setup_ghostty() -> None:
    """make a file (and folders) for $HOME/.config/ghostty/config and write to it"""
    config_dir = Path.home() / ".config" / "ghostty"
    config_dir.mkdir(parents=True, exist_ok=True)
    config_file = config_dir / "config"
    config_contents = inspect.cleandoc("""
    copy-on-select = clipboard
    font-family = "SFMono"
    theme = "tokyonight"
    """).strip() + "\n"
    if not config_file.exists():
        config_file.write_text(config_contents)

def set_up_ssh_keys_and_github() -> None:
    """
    Ensure ~/.ssh/id_ed25519 exists, is loaded into the system agent,
    and (optionally) stored in the macOS keychain.
    """
    ssh_dir = Path.home() / ".ssh"
    ssh_dir.mkdir(mode=0o700, exist_ok=True)

    key_path = ssh_dir / "id_ed25519"
    pub_path = key_path.with_suffix(".pub")

    # 1. Generate a key exactly once
    if not key_path.exists():
        email = input("Email address for the new SSH key: ").strip()
        run(["ssh-keygen", "-t", "ed25519", "-f", str(key_path), "-C", email], sudo=False)

    # 2. Check if the key is already in the agent
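    def key_loaded() -> bool:
        # `ssh-add -l` lists fingerprints and comments, not file names, so compare fingerprints
        fp = subprocess.run(["ssh-keygen", "-lf", str(pub_path)], capture_output=True, text=True)
        res = subprocess.run(["ssh-add", "-l"], capture_output=True, text=True)
        if fp.returncode != 0 or res.returncode != 0:
            return False
        return fp.stdout.split()[1] in res.stdout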

    if not key_loaded():
        # macOS runs an ssh-agent under launchd by default; just add the key.
        run(["ssh-add", "--apple-use-keychain", str(key_path)], sudo=False)

    # 3. Ensure SSH config adds keys automatically (once)
    cfg = ssh_dir / "config"
    stanza = inspect.cleandoc(f"""
    # ~/.ssh/config
    # This file is automatically generated by setup_computer.py
    # It configures SSH to use the key and add it to the agent.
    Host *
        AddKeysToAgent yes
        UseKeychain yes
        IdentityFile ~/.ssh/id_ed25519
    """).strip() + "\n"
    if not cfg.exists() or "AddKeysToAgent yes" not in cfg.read_text():
        cfg.write_text(cfg.read_text() + stanza if cfg.exists() else stanza, encoding="utf-8")

    # 4. Remind the user to register the key with GitHub
    print(
        f"\n   SSH key ready → copy it with:\n"
        f"    pbcopy < {pub_path}\n"
        "Then add it to https://github.com/settings/keys\n"
    )


def main() -> None:
    if platform.system() != "Darwin":
        sys.exit("This script is intended for macOS.")

    install_homebrew()
    install_mas()
    mas_apps()
    brew_casks()
    brew_cli()
    keyboard_tweaks()
    setup_starship()
    customize_zshrc()
    set_up_ssh_keys_and_github()
    setup_ghostty()

    print("Setup complete. Log out/in (or reboot) for keyboard changes to apply.")

if __name__ == "__main__":
    try:
        main()
    except subprocess.CalledProcessError as exc:
        print(f"Command failed with exit code {exc.returncode}", file=sys.stderr)
        sys.exit(exc.returncode)

https://rachitsingh.com/notes/new-mac-setup-script/
Learning

Gardner was so serious about this learning imperative, so determined that the message would get through, that he wrote the speech out in advance because he wanted “every sentence to hit its target.”

What was his message? “We have to face the fact that most men and women out there in the world of work are more stale than they know, more bored than they would care to admit,” he said. “Boredom is the secret ailment of large-scale organizations. Someone said to me the other day ‘How can I be so bored when I’m so busy?’ I said ‘Let me count the ways.’ Look around you. How many people whom you know well — people even younger than yourselves—are already trapped in fixed attitudes and habits?”

So what is the opposite of boredom, the personal attribute that allows individuals to keep learning, growing, and changing, to escape their fixed attitudes and habits? “Not anything as narrow as ambition,” Gardner told the ambitious McKinsey strategists. “After all, ambition eventually wears out and probably should. But you can keep your zest until the day you die.” He then offered a simple maxim to guide the accomplished leaders in the room. “Be interested,” he urged them. “Everyone wants to be interesting, but the vitalizing thing is to be interested…As the proverb says, ‘It’s what you learn after you know it all that counts.’”

-- Bill Taylor (HBR)

https://rachitsingh.com/learning/
Software to make Macs friendlier to Windows people

macOS has a few UX differences from Windows that can make it difficult to switch from Windows to macOS. Here's some tools and tricks to work around them:

Remap keys

You can change some default keybinds using the Settings app directly under Customize modifier keys (search for this in Settings). If you're more gung-ho, you can use hidutil (easier with this software tool to generate configs).

You can also use third-party software, some of it paid. The best free tool is Karabiner-Elements which is open source. In practice I remapped some things using the modifier keys, and then remapped some other things using Karabiner.

Alt-tab Windows-style

In macOS, window switching is split in two: command-tab (win-tab for Windows keyboards) switches between applications, and command-` switches between windows of the same application. There's no built-in shortcut that cycles through every window the way alt-tab does on Windows. This is fundamental because clicking "X" on a macOS window does not exit the application, it merely closes the current window. To fully exit an application, you have to hit command + Q or close it using the application menu on the upper left.

I find this confusing and just want to switch between all my recent windows when using alt-tab, so I installed an app called AltTab.

Search

Windows search (usually win key) is pretty terrible, and Spotlight is quite a bit better. So you should relearn the muscle memory to just hit command + space when you want to search.

Window Snapping

macOS has recently added better window snapping support, but for a long time it was not as good as Windows. You can obviously use a proper tiling window manager, but if you're lazy, I find Rectangle is the simplest way to get the same behavior. I've heard Loop is also pretty good and intuitive but there's slightly more of a learning curve.

Next time I set up a new Mac I'll document what I install (there's some software you can't live without, and some general hygiene stuff like not overwriting system Python) and put it here.

I stopped installing Rectangle in favor of Raycast's window management extension.
https://rachitsingh.com/notes/windows-to-macos/
Ghostty

Ghostty is out. It looks and feels great.

Despite being in the Discord server for about a year I didn't get access until today, which is pretty fair1.

In summary: I'm a big fan of Mitchell Hashimoto's work and his decision-making around this project. It seems like a labor of love (which is true for all terminal emulators) and I appreciate that I get all the benefit of that now.

I think in terms of latency-sensitivity I'm probably near the 95th %ile (in that even a small amount of latency when I'm typing causes me to get frustrated). So I can feel the difference between this and iTerm2, though it might still be a placebo.

tip: check your remote .bashrc

I'll write down more as I go, but here's the first: on the remote server I work on, when I logged in with Ghostty, there were no colors. The culprit was a line that looked like this in the ~/.bashrc:

case "$TERM" in
    xterm-color|*-256color) color_prompt=yes;;
esac

since Ghostty advertises itself as xterm-ghostty but can take colors, you should change this to:

case "$TERM" in
    xterm-color|*-256color|xterm-ghostty) color_prompt=yes;;
esac

This is a problem with your default ~/.bashrc, not Ghostty (which correctly informs the terminfo database).

color scheme: nightowl

I use Sarah Drasner's Night Owl theme in VSCode and everywhere else. I converted it to a Ghostty theme (imperfectly, I'm not a designer) here:

~ ❯ cat ~/.config/ghostty/themes/nightowl
palette = 0=#011628
palette = 1=#EF5350
palette = 2=#22DA6E
palette = 3=#ADDB67
palette = 4=#82AAFF
palette = 5=#C792EA
palette = 6=#21C7A8
palette = 7=#FFFFFF
palette = 8=#575656
palette = 9=#EF5350
palette = 10=#22DA6E
palette = 11=#FFEB95
palette = 12=#82AAFF
palette = 13=#C792EA
palette = 14=#7FDBCA
palette = 15=#FFFFFF
background = 011628
foreground = d6deeb
cursor-color = 7e57c2
selection-background = 5f7e97
selection-foreground = dee4ee

tmux on a remote server

I don't fully understand why this is an issue, but when I SSH into a remote server and try to start tmux, I get the following error:

missing or unsuitable terminal: xterm-ghostty

The solution is to copy the terminfo information to the remote host like this:

infocmp -x | ssh $HOST -- tic -x -

credit for this fix goes to Kovid Goyal's kitty documentation as well as the corresponding Ghostty page

wishlist

I'd love a Windows build that is almost as good as the macOS one. For aspiring programmers out there (I guess pretty ambitious ones that can learn Zig): if you submit PRs for this, you might get a code review from @mitchellh, which I think is pretty valuable. This is the starting point.

  1. I didn't have time to actually contribute anything to the project. It would take a while for me to figure it out, and it wasn't a priority.

https://rachitsingh.com/notes/ghostty/
Don't just disagree, ask why

When disagreeing with someone about a decision or a fact, it's useful to ask someone why rather than just presenting your own view. It reorients the discussion as the team vs the problem, and gives everyone space to acknowledge why they might be wrong.

Curiosity

Ted Lasso is back for a 4th season, which made me think of this moment in the first1. One of my favorite scenes is when Ted plays darts with Rupert. Here's a transcript for the video-weary:

You know Rupert, guys have underestimated my entire life. And for years I never understood why. It used to really bother me. But then one day I was driving my little boy to school and I saw this quote by Walt Whitman and it was painted on the wall there that said "Be curious. Not judgmental."2 I like that. thwack.

So I get back in my car and I'm driving to work and all of a sudden it hits me. All them fellas that used to belittle me, not a single one of 'em was curious. You know they thought they had everything figured out, so they judged everything, and they judged everyone. And I realized that their underestimating me, who I was had nothing to do with it. Because if they were curious, they would ask questions.

You know. Questions like, have you played a lot of darts, Ted? thwack. To which I would have answered: Yes sir. Every Sunday afternoon at a sports bar with my father from age 10 til I was 16 when he passed away. Barbecue sauce. thwack.

The scene has a lot of emotional appeal; everyone likes when the main character is secretly more skilled than they're letting on. There's a TV Trope for it: I Am Not Left-Handed, which Ted Lasso even calls out directly:

Ted: Oh, wait a second. I forgot I'm left handed.

Outside of avoiding judgement, being curious is the start of good listening3. Repeating back what someone is saying in your own words is both communicating that you're actively listening and something that you naturally do when you're curious.

One way to disagree

I think there's one further version of this idea, which is helpful when navigating conflict. It's pretty straightforward: don't just disagree, ask why.

There are a lot of places where basic disagreements come up. For example, at trivia, you might think: "I know for sure that Jane Addams was born in 1860" and your friend might be equally confident that she was born in the 1880s.

The least useful way to communicate is to just be confident that the other person is wrong (e.g. "No, she was born in 1860."). Even if you've just finished reading a book on her life, you can do a better job of communicating than just shutting down their idea with a "that's wrong". A lot of people will change their mind if you ask a question like: "Why do you think so? Where did you read that?". If their memory is vague (e.g. "I remember reading about her after the Civil War section") and yours is precise ("I was reading a book last week that mentioned that she was born just before the Civil War in Illinois to a family with 8 children"), in the vast majority of cases they will acknowledge that you're probably right4.

This method is helpful in contexts where the stakes are low (who wants to have a fight about a trivia question?), but it also applies to high-stakes decisions. By asking for the source of their decision or knowledge, you're communicating that you are a team trying to find the right solution. This isn't a passive, unassertive, or conflict-avoidant way to communicate - it's just acknowledging that they could have a self-consistent explanation that comes to a different conclusion.

Of course, it also helps you avoid looking silly by being dead-certain about something you're wrong about.

Another reason to communicate this way (again, if it's useful to your particular context) is that it gives some people who have different unspoken assumptions a way to communicate. For example, let's say you're hanging out in a left-leaning circle, and someone says something like "unions are bad". You might at first just think that they are more conservative-leaning or at least in favor of business deregulation. But people make mistakes in communicating, especially around unspoken assumptions. If you ask them why they think that instead of moving on with your day, you could learn that they're talking about public sector unions, like police unions, because private unions being good is, for them, a given. You can learn a lot from asking people about things that seem obviously wrong.

You can apply this approach to minuscule, pointless conflicts as well. For example, if there's a disagreement about who lost the TV remote, by asking "why" you might be able to piece together a chain of observations which could lead to the remote. You might also discover that your partner is frustrated with your lack of organization, or inability to work with their organization scheme.

Finally, a decision

This isn't a foolproof way to communicate, and eventually someone (you, the group, etc.) will have to make a decision on which way to go. Or, you can agree to disagree. Both are actually much easier when you both understand the merits of each other's arguments. It is a lot easier to agree to "50-50" two reasonable-looking options than to feel like your opinion isn't being considered or you're being dismissed out of hand.

  1. It never quite lived up to that standard again, but it is still a very enjoyable watch through S3.

  2. According to Snopes, the quote isn't by Walt Whitman, but it's an easier story to tell that way.

  3. Ben's blog has some useful examples oriented around being a better listener. The core idea is helpful, but the takeaway is the one you always hear as a chronically solution-oriented person: just listen instead of offering solutions.

  4. And I should just say - this is just one way of communicating. In certain friend groups and certain cultures, it's more accepted to just reject someone's opinion you disagree with - it's expected, even. It can just be really abrasive to encounter that in a group where the norm is different. If someone is arguing in bad faith (i.e. is not really searching for the truth), this obviously doesn't work. But it can work in most peer-to-peer and close manager-to-worker relationships.

https://rachitsingh.com/dont-disagree-ask-why/
Dot products in Rust
This post is a work in progress. I'll update it as I go, and I might be missing very obvious things (or haven't gotten around to it yet). Feel free to shoot me an email if you want to make a comment.

A few months ago I was helping with a Rust-based Llama2 inference project and learned a few things about optimizing CPU SIMD code. One thing I couldn't shake is that codegen in Rust is still pretty bad at the moment, at least for neural network inference.

Here's a comparison of two different dot products, which are written by the portable_simd group and by StackOverflow user Soonts, in Rust and C++ respectively.

#![feature(array_chunks)]
#![feature(slice_as_chunks)]
#![feature(portable_simd)]

use std::simd::num::*;
use std::simd::*;

pub fn dot_prod_simd_5(a: &[f32], b: &[f32]) -> f32 {
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .fold(f32x4::splat(0.), |acc, (a, b)| a.mul_add(b, acc))
        .reduce_sum()
}

fn main() {
    // initialize two large arrays with sin(x/100) and cos(x/100), each with N elements
    // multiply K times
    let N = 1000000;
    let K = 10000;
    let a: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).sin()).collect();
    let b: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).cos()).collect();

    // compute the dot product of the two arrays
    // do this dot product enough times to reduce fixed costs
    let mut result = 0.0;
    for _ in 0..K {
        result = dot_prod_simd_5(&a, &b);
    }
    println!("result: {}", result);
}

To properly compile this, you have to use a nightly Rust compiler, and set rustflags = ["-Ctarget-cpu=native"] so that codegen properly uses your AVX2-capable machine.
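If you haven't set rustflags before, one place it can live is a per-project Cargo config file. This is a minimal sketch, assuming you want the flag scoped to this project only rather than set globally:

```toml
# .cargo/config.toml (in the project root)
[build]
rustflags = ["-Ctarget-cpu=native"]
```

Setting the RUSTFLAGS environment variable for a single build is an alternative. Either way, the resulting binary is only guaranteed to run on CPUs that have the same features as the build machine.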

And in C++ (NOTE! not my code, this is from StackOverflow, credit Soonts):

#include <immintrin.h>
#include <vector>
#include <algorithm>
#include <assert.h>
#include <stdint.h>
#include <cmath>
#include <cstdio>

using std::ptrdiff_t;

// CPUs support RAM access like this: "ymmword ptr [rax+64]"
// Using templates with offset int argument to make easier for compiler to emit good code.

// Returns acc + ( p1 * p2 ), for 8 float lanes
template<int offsetRegs>
inline __m256 fma8( __m256 acc, const float* p1, const float* p2 )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    const __m256 a = _mm256_loadu_ps( p1 + lanes );
    const __m256 b = _mm256_loadu_ps( p2 + lanes );
    return _mm256_fmadd_ps( a, b, acc );
}

#ifdef __AVX2__
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // Make a mask of 8 bytes
    // These aren't branches, they should compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    uint64_t mask = -( missingLanes < 8 );
    mask >>= missingLanes * 8;
    // Sign extend the bytes into int32 lanes in AVX vector
    __m128i tmp = _mm_cvtsi64_si128( (int64_t)mask );
    return _mm256_cvtepi8_epi32( tmp );
}
#else
// Aligned by 64 bytes
// The load will only touch a single cache line, no penalty for unaligned load
static const int alignas( 64 ) s_remainderLoadMask[ 16 ] = {
    -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0 };
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // These aren't branches, they compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    missingLanes = std::min( missingLanes, (ptrdiff_t)8 );
    // Unaligned load from a constant array
    const int* rsi = &s_remainderLoadMask[ missingLanes ];
    return _mm256_loadu_si256( ( const __m256i* )rsi );
}
#endif

// Same as fma8(), load conditionally using the mask
// When the mask has all bits set, an equivalent of fma8(), but 1 instruction longer
// When the mask is a zero vector, the function won't load anything, will return `acc`
template<int offsetRegs>
inline __m256 fma8rem( __m256 acc, const float* p1, const float* p2, ptrdiff_t rem )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    // Generate the mask for conditional loads
    // The implementation depends on whether AVX2 is enabled with compiler switches
    const __m256i mask = makeRemainderMask( ( 8 + lanes ) - rem );
    // These conditional load instructions produce zeros for the masked out lanes
    const __m256 a = _mm256_maskload_ps( p1 + lanes, mask );
    const __m256 b = _mm256_maskload_ps( p2 + lanes, mask );
    return _mm256_fmadd_ps( a, b, acc );
}

// Compute dot product of float vectors, using 8-wide FMA instructions
float dotProductFma( const std::vector<float>& a, const std::vector<float>& b )
{
    assert( a.size() == b.size() );
    const size_t length = a.size();
    if( length == 0 )
        return 0.0f;

    const float* p1 = a.data();
    const float* p2 = b.data();
    // Compute length of the remainder;
    // We want a remainder of length [ 1 .. 32 ] instead of [ 0 .. 31 ]
    const ptrdiff_t rem = ( ( length - 1 ) % 32 ) + 1;
    const float* const p1End = p1 + length - rem;

    // Initialize accumulators with zeros
    __m256 dot0 = _mm256_setzero_ps();
    __m256 dot1 = _mm256_setzero_ps();
    __m256 dot2 = _mm256_setzero_ps();
    __m256 dot3 = _mm256_setzero_ps();

    // Process the majority of the data.
    // The code uses FMA instructions to multiply + accumulate, consuming 32 values per loop iteration.
    // Unrolling manually for 2 reasons:
    // 1. To reduce data dependencies. With a single register, every loop iteration would depend on the previous result.
    // 2. Unrolled code checks for exit condition 4x less often, therefore more CPU cycles spent computing useful stuff.
    while( p1 < p1End )
    {
        dot0 = fma8<0>( dot0, p1, p2 );
        dot1 = fma8<1>( dot1, p1, p2 );
        dot2 = fma8<2>( dot2, p1, p2 );
        dot3 = fma8<3>( dot3, p1, p2 );
        p1 += 32;
        p2 += 32;
    }

    // Handle the last, possibly incomplete batch of length [ 1 .. 32 ]
    // To save multiple branches, we load that entire batch with `vmaskmovps` conditional loads
    // On modern CPUs, the performance of such loads is pretty close to normal full vector loads
    dot0 = fma8rem<0>( dot0, p1, p2, rem );
    dot1 = fma8rem<1>( dot1, p1, p2, rem );
    dot2 = fma8rem<2>( dot2, p1, p2, rem );
    dot3 = fma8rem<3>( dot3, p1, p2, rem );

    // Add 32 values into 8
    dot0 = _mm256_add_ps( dot0, dot2 );
    dot1 = _mm256_add_ps( dot1, dot3 );
    dot0 = _mm256_add_ps( dot0, dot1 );
    // Add 8 values into 4
    __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( dot0 ),
        _mm256_extractf128_ps( dot0, 1 ) );
    // Add 4 values into 2
    r4 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
    // Add 2 lower values into the scalar result
    r4 = _mm_add_ss( r4, _mm_movehdup_ps( r4 ) );

    // Return the lowest lane of the result vector.
    // The intrinsic below compiles into noop, modern compilers return floats in the lowest lane of xmm0 register.
    return _mm_cvtss_f32( r4 );
}

int main(int argc, char** argv) {
    int N = 1000000;
    int K = 10000;
    std::vector<float> a(N);
    std::vector<float> b(N);
    for (int i = 0; i < N; i++) {
        a[i] = sin(i/100.0);
        b[i] = cos(i/100.0);
    }
    float result = 0.0;
    for (int i = 0; i < K; i++) {
        result = dotProductFma(a, b);
    }
    std::printf("result: %.7f\n", result);
    return 0;
}

I compiled this with g++ -O3 -march=native -o main.

As far as I can tell, these pieces of code are functionally identical. There are enough loops that creating the arrays for the first time shouldn't be a problem.

They also produce slightly different results, which might be from different float accumulation patterns.
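To see how accumulation order alone can move the result, here's a small sketch (stable Rust, no SIMD; dot_sequential and dot_lanes are hypothetical helpers, not the benchmarked code). It sums the same products with one accumulator, and with several independent partial sums the way SIMD lanes or unrolled accumulators would. Floating-point addition isn't associative, so the totals generally don't agree to the last digit:

```rust
/// Sum the products left to right with a single accumulator.
fn dot_sequential(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
}

/// Sum the products into `LANES` independent partial sums (like SIMD lanes),
/// then add the partial sums together at the end.
fn dot_lanes<const LANES: usize>(a: &[f32], b: &[f32]) -> f32 {
    let mut partial = [0.0f32; LANES];
    for (i, (x, y)) in a.iter().zip(b).enumerate() {
        partial[i % LANES] += x * y;
    }
    partial.iter().sum()
}

fn main() {
    // Same inputs as the benchmark programs above.
    let n = 1_000_000;
    let a: Vec<f32> = (0..n).map(|i| (i as f32 / 100.0).sin()).collect();
    let b: Vec<f32> = (0..n).map(|i| (i as f32 / 100.0).cos()).collect();

    println!("1 accumulator:   {}", dot_sequential(&a, &b));
    println!("4 accumulators:  {}", dot_lanes::<4>(&a, &b));
    println!("32 accumulators: {}", dot_lanes::<32>(&a, &b));
}
```

The exact digits depend on the inputs and the grouping, so no expected output is shown here.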

~/proj/profile/rust_dp main ?33 ❯ ./target/release/rust_dp
result: 4.5250244
~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621

and for good measure using clang++:

~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621

Benchmarks

~/proj/profile/rust_dp main ?33 ❯ hyperfine --warmup 3 "target/release/rust_dp"
Benchmark 1: target/release/rust_dp
  Time (mean ± σ):      2.261 s ±  0.006 s    [User: 2.234 s, System: 0.006 s]
  Range (min … max):    2.255 s …  2.273 s    10 runs
~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
  Time (mean ± σ):     752.5 ms ±  19.3 ms    [User: 726.2 ms, System: 3.1 ms]
  Range (min … max):   724.9 ms … 785.9 ms    10 runs

and wow here's clang++:

~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
  Time (mean ± σ):     233.4 ms ±   3.2 ms    [User: 207.6 ms, System: 3.3 ms]
  Range (min … max):   228.7 ms … 238.4 ms    11 runs

Ideas

  1. Rust code could still be generating bounds checks.
  2. We aren't unrolling optimally above and using all the registers.
  3. We need -ffast-math to get some small microoptimizations.
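
As a starting point for ideas 1 and 2, here's a sketch in stable Rust (dot_unrolled, LANES and UNROLL are names made up for illustration, and this hasn't been benchmarked against the numbers above). It uses chunks_exact so every chunk has a fixed length, which usually helps the optimizer drop bounds checks, and keeps four independent accumulators of eight lanes each, mirroring the structure of the C++ version:

```rust
const LANES: usize = 8; // f32 values in one 256-bit register
const UNROLL: usize = 4; // number of independent accumulators

/// Dot product that consumes LANES * UNROLL elements per loop iteration.
fn dot_unrolled(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let mut acc = [[0.0f32; LANES]; UNROLL];

    let a_chunks = a.chunks_exact(LANES * UNROLL);
    let b_chunks = b.chunks_exact(LANES * UNROLL);
    // The tails hold the final `len % 32` elements.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    for (ca, cb) in a_chunks.zip(b_chunks) {
        for u in 0..UNROLL {
            for l in 0..LANES {
                let i = u * LANES + l;
                // Each acc[u] only depends on its own previous value.
                acc[u][l] = ca[i].mul_add(cb[i], acc[u][l]);
            }
        }
    }

    // Collapse the 32 partial sums, then handle the leftover elements.
    let mut sum: f32 = acc.iter().flatten().sum();
    for (x, y) in a_tail.iter().zip(b_tail) {
        sum += x * y;
    }
    sum
}

fn main() {
    let a: Vec<f32> = (0..100).map(|i| (i % 7) as f32).collect();
    let b: Vec<f32> = (0..100).map(|i| (i % 5) as f32).collect();
    println!("dot = {}", dot_unrolled(&a, &b));
}
```

Whether this autovectorizes into something comparable to the C++ is worth checking in the disassembly. Note that mul_add only becomes a hardware FMA when the target supports it (e.g. with -Ctarget-cpu=native); otherwise it can fall back to a much slower software routine.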
https://rachitsingh.com/rust-dot-product/
How to add an accordion

It turns out it's really easy to add a click-to-expand element (also a dropdown or accordion element) in Markdown + HTML without any JavaScript, since it's part of the spec.

This:

<details>
    <summary>
        Click here to expand
    </summary>
    Surprise!
</details>

turns into this:

Click here to expand Surprise!

Of course you can make this fancier with CSS and whatnot. But the core action is built into HTML, so there's no need for JavaScript.

If you want a Zola shortcode, you can use this example one:

templates/shortcodes/summary.html:

<details>
    <summary>{{ title }}</summary>
    {{ body }}
</details>

and

{% summary(title="Click here") %}
Hello, world!
{% end %}

which turns into this:

Click here Hello, world!

Found here: gist

https://rachitsingh.com/notes/how-to-add-a-dropdown/
Setting up CUDA on WSL2 in 2023

Setting up CUDA on WSL2 probably costs you a little bit of performance but gives you a lot of flexibility - you're essentially running a VM with (some) Microsoft support. Here's the best source: Enable NVIDIA CUDA in WSL.

Luckily in the past year or so NVIDIA has started supporting this path, so you're (kind of) in good hands. A lot of advice will be mixed between Ubuntu and WSL, though, so I'm writing down my experience with installing CUDA on WSL2 today.

The below is a bit of a bloated install, so you might be able to get away with fewer things (for example TensorRT is pretty optional).

Install drivers

First, you'll need to get an NVIDIA driver (in Windows) for your card. I got the "NVIDIA Game Ready" driver from NVIDIA's website directly. Note that this is not installing an NVIDIA driver for your WSL (i.e. Linux) installation. As far as I can tell, this process uses your Windows driver directly (which makes sense, having two drivers for the same physical device running seems like a recipe for disaster).

I chose:

  • No GeForce Experience
  • HD Audio Driver
  • PhysX

though of course you can make other choices. Once this is finished, you will see a bunch of files in C:\Windows\System32\lxss\lib\:

~ ❯ ls -l /mnt/c/Windows/System32/lxss/lib/
-r-xr-xr-x 1 singhrac singhrac 10188744 Aug 15 00:23 libcudadebugger.so.1
-r-xr-xr-x 1 singhrac singhrac   154088 Aug 15 00:23 libcuda.so
-r-xr-xr-x 1 singhrac singhrac   154088 Aug 15 00:23 libcuda.so.1
-r-xr-xr-x 1 singhrac singhrac   154088 Aug 15 00:23 libcuda.so.1.1
-r-xr-xr-x 3 singhrac singhrac  5401440 Aug 11 11:05 libd3d12core.so
-r-xr-xr-x 3 singhrac singhrac   800296 Aug 11 11:05 libd3d12.so
-r-xr-xr-x 2 singhrac singhrac   827904 Jun  5  2021 libdxcore.so
-r-xr-xr-x 1 singhrac singhrac 10820792 Aug 15 00:23 libnvcuvid.so
-r-xr-xr-x 1 singhrac singhrac 10820792 Aug 15 00:23 libnvcuvid.so.1
-r-xr-xr-x 1 singhrac singhrac 37086552 Aug 15 00:23 libnvdxdlkernels.so
-r-xr-xr-x 1 singhrac singhrac   551528 Aug 15 00:23 libnvidia-encode.so
-r-xr-xr-x 1 singhrac singhrac   551528 Aug 15 00:23 libnvidia-encode.so.1
-r-xr-xr-x 1 singhrac singhrac   233832 Aug 15 00:23 libnvidia-ml.so.1
-r-xr-xr-x 1 singhrac singhrac   362960 Aug 15 00:23 libnvidia-opticalflow.so
-r-xr-xr-x 1 singhrac singhrac   362960 Aug 15 00:23 libnvidia-opticalflow.so.1
-r-xr-xr-x 1 singhrac singhrac    68560 Aug 15 00:23 libnvoptix.so.1
-r-xr-xr-x 1 singhrac singhrac 83265296 Aug 15 00:23 libnvwgf2umx.so
-r-xr-xr-x 1 singhrac singhrac   678064 Aug 15 00:23 nvidia-smi

and the same set of files in /usr/lib/wsl/lib/. NOTE! This is where the CUDA driver libraries (libcuda.so and friends) live. However, you also need the CUDA Toolkit (probably). Note that nvidia-smi will now work if you add /usr/lib/wsl/lib/ to your PATH:

~ ❯ nvidia-smi
Thu Aug 24 11:17:06 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.103                Driver Version: 537.13       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2070        On  | 00000000:07:00.0  On |                  N/A |
|  0%   51C    P5              20W / 175W |    573MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

However the "CUDA Version" at the top is the max supported CUDA Toolkit version that you can use1. At this point in time (20230824) PyTorch doesn't really support CUDA 12.2, so it might be a huge pain in the ass if you use that. I recommend 11.8, the currently max supported version on the install page. I'm sure this will change shortly.

At this point, you probably want to add /usr/lib/wsl/lib/ to your LD_LIBRARY_PATH, which will let programs find the .so files there. Put this in your ~/.bashrc or ~/.zshrc:

export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH

If you run ldconfig -p | grep libcuda, this will now show those libraries (you might have to exec zsh or restart your shell).

Install CUDA Toolkit

It is really important to listen to NVIDIA's instructions here: the default CUDA Toolkit comes with a driver, which will overwrite the above set (i.e. piggy-backing off Windows driver). So it's important to install the WSL-specific CUDA Toolkit, which doesn't have that driver. This is also where it's really easy to install the latest CUDA Toolkit (i.e. 12.2), which might cause issues for you. I didn't run into too many the first time, but I think it's easier just to install the older toolkit.

Go here: CUDA Toolkit 11.8. You can find other versions in the "Archive of Previous CUDA Releases" link.

You want to choose: Linux > x86_64 > WSL-Ubuntu > 2.0 > (your choice of installer type)

I personally find it easier to deal with deb installation because then things like dpkg will help you manage the installation, but everyone's experience is different. In any case, you want to follow those instructions, making sure you have WSL-Ubuntu selected. If you use the deb (network) install, it might try to install 12.2 because that's the most recent version. I think you can work around this by doing sudo apt-get install cuda-toolkit-11-8, but I used the deb (local) version instead.

I personally ran:

~ ❯ wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
~ ❯ sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
~ ❯ wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda-repo-wsl-ubuntu-11-8-local_11.8.0-1_amd64.deb
~ ❯ sudo dpkg -i cuda-repo-wsl-ubuntu-11-8-local_11.8.0-1_amd64.deb
~ ❯ sudo cp /var/cuda-repo-wsl-ubuntu-11-8-local/cuda-*-keyring.gpg /usr/share/keyrings/
~ ❯ sudo apt-get update
~ ❯ sudo apt-get -y install cuda-toolkit-11-8

Note the change in the last line - I wanted to pin the specific version. This installs a lot of stuff for you: nvprof, nvcc, nvprune, libcublas-dev, libcufft, libcusparse.

~ ❯ ls -l /usr/local
total 36
drwxr-xr-x  2 root root 4096 Aug 24 11:48 bin
lrwxrwxrwx  1 root root   22 Aug 24 11:48 cuda -> /etc/alternatives/cuda
lrwxrwxrwx  1 root root   25 Aug 24 11:48 cuda-11 -> /etc/alternatives/cuda-11
drwxr-xr-x 15 root root 4096 Aug 24 11:48 cuda-11.8
drwxr-xr-x  2 root root 4096 Feb 10  2023 etc
drwxr-xr-x  2 root root 4096 Feb 10  2023 games
drwxr-xr-x  2 root root 4096 Feb 10  2023 include
drwxr-xr-x  3 root root 4096 Feb 10  2023 lib
lrwxrwxrwx  1 root root    9 Feb 10  2023 man -> share/man
drwxr-xr-x  2 root root 4096 Feb 10  2023 sbin
drwxr-xr-x  6 root root 4096 Mar  7 17:10 share
drwxr-xr-x  2 root root 4096 Feb 10  2023 src

I'm not entirely sure why there's two separate cuda and cuda-11 folders, and they don't seem to be symlinks to each other. This is a bit concerning, but I'm not going to worry about it too much. My ldconfig -p | grep libcublas points to the /usr/local/cuda variant, but when I check nvcc there, it says the version of the toolkit is 11.8.

Note, however, that it doesn't install cuDNN or TensorRT (aka nvinfer). So let's do that next.

cuDNN and TensorRT

NVIDIA now has a deep learning installation guide here: NVIDIA cuDNN. I personally grabbed cuDNN v8.9.4 for CUDA 11.x, "Local Installer for Ubuntu22.04 x86_64 (Deb)", and then the rest of the instructions from that linked page worked fine.

Just in case it's unclear, the runtime library installation command was sudo apt-get install libcudnn8=8.9.4.25-1+cuda11.8 (after you've set up the local repository). I didn't install the code samples. After running sudo ldconfig, I was able to find cuDNN using ldconfig -p | grep libcudnn.

For TensorRT, the instructions are a bit confusing but you can basically do the same thing:

~ ❯ sudo dpkg -i nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8_1.0-1_amd64.deb
~ ❯ sudo cp /var/nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8/nv-tensorrt-local-0628887B-keyring.gpg /usr/share/keyrings/

The latter command is the output of the former, so you don't really need to worry about version information - just download the right .deb from the link, and then copy and paste the keyring command.

From here, I installed the full runtime, but not the Python packages - I like to do that inside virtualenvs (since, for example, you might want a different base Python version).

Python packages

Now you can install torch like you always wanted (in a virtualenv, right?):

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Make sure you check it works!

~ ❯ python -c "import torch; print(torch.cuda.is_available())"
True

Uninstall

I had to temporarily uninstall as part of messing everything up, and it wasn't too bad. Mostly I purged some apt packages using stuff like:

sudo apt-get --purge remove "cuda*"

and swapped cuda for nvidia and nsight.

This gets rid of most things, but as mentioned above, this will keep nvidia-smi around, since that's from the C:\Windows\System32\lxss\lib\ install, which is kept in /usr/lib/wsl/lib/. To remove those, you need to uninstall the NVIDIA graphics driver from your Windows machine (note for me this caused Alacritty to break, but PowerShell let me access WSL just fine).

  1. As far as I can tell! I'm putting this out there and hoping Cunningham's law will get the right answer out, but that depends on you to correct me. Please email me at (fullname) [@] outlook.com.

https://rachitsingh.com/notes/wsl-cuda/
(Not) fast dot products via SIMD

Lately I've been tinkering with optimizing Sasha's llama2.rs, a fast Rust port of Karpathy's llama2.c. It takes advantage of Rust nightly's portable_simd, which allows you to emit AVX2 or AVX512 instructions using a relatively clean set of abstractions, and also run inference on a quantized llama2, so it's pretty fast.

One of the core loops looks like this:

let mask = (1 << BITS) - 1; // BITS is 4
let elems_per_i32 = 32 / BITS;
let ipg: usize = GROUPSIZE / 32 * BITS; // GROUPSIZE is 128, so this is 16 (left-associative)
let mask_4bits = i32x8::splat(mask);
let shift_right = i32x8::from_array([0, 4, 8, 12, 16, 20, 24, 28]);
...
// Do K at a time
let zero = f32x8::splat(0.0);
let qzeros = &self.qzeros[oi / elems_per_i32]; // self is the neural net block this is part of
let out_elem = oi % elems_per_i32;
let qweight = self.qweight[oi].chunks_exact(ipg);

let collect = self.scales[oi]
    .into_iter()
    .zip(qweight)
    .enumerate()
    .map(|(group, (scale, qweight))| {
        let qz = ((qzeros[group] >> (BITS * out_elem)) & mask) + 1;
        let scale_simd = f32x8::splat(scale);
        let zero_simd = i32x8::splat(qz);
        let in_pos = group * GROUPSIZE;
        let xs = x[in_pos..in_pos + GROUPSIZE].chunks_exact(8);
        qweight
            .iter()
            .zip(xs)
            .map(|(v, x)| {
                //Extract v into 8 chunks
                let x = f32x8::from_slice(x);
                let num_simd = i32x8::splat(*v);
                let qw: i32x8 = (num_simd >> shift_right) & mask_4bits;
                let combine: f32x8 = (qw - zero_simd).cast::<f32>();
                let weight: f32x8 = scale_simd * combine;
                weight * x
            })
            .fold(zero, |x, y| x + y)
    })
    .fold(zero, |x, y| x + y);
*o = collect.reduce_sum(); // output reference to write to

Here, the f32x8 and i32x8 primitives are aliases for the std::simd types that denote a SIMD vector with 8 lanes of f32 and i32 respectively (256 bits each). I have a 5700X, a Zen 3 CPU with AVX2 support but not AVX512.
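
To make the bit-twiddling in that loop concrete, here's a scalar sketch of what happens to a single packed i32 (unpack_nibbles and dequantize are hypothetical names, and zero_point stands in for what the loop calls qz; this is an illustration, not llama2.rs code):

```rust
const BITS: u32 = 4;

/// Split one packed i32 into its eight 4-bit values, lowest nibble first.
fn unpack_nibbles(v: i32) -> [i32; 8] {
    let mask: i32 = (1i32 << BITS) - 1; // 0b1111
    let mut out = [0i32; 8];
    for (lane, slot) in out.iter_mut().enumerate() {
        // Scalar version of `(num_simd >> shift_right) & mask_4bits`:
        // lane 0 shifts by 0, lane 1 by 4, ..., lane 7 by 28.
        *slot = (v >> (BITS * lane as u32)) & mask;
    }
    out
}

/// Turn the quantized values into f32 weights: scale * (q - zero_point).
fn dequantize(v: i32, zero_point: i32, scale: f32) -> [f32; 8] {
    unpack_nibbles(v).map(|q| scale * (q - zero_point) as f32)
}

fn main() {
    let packed = 0x7654_3210;
    println!("{:?}", unpack_nibbles(packed)); // [0, 1, 2, 3, 4, 5, 6, 7]
    println!("{:?}", dequantize(packed, 8, 0.5));
}
```

The SIMD loop does the same thing for all eight lanes at once, then multiplies the resulting weights by eight elements of x.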

When I profiled this, I saw that a lot of the time was being spent in the fold on the inside (i.e. the sum over results from mapping over qweight).

Ok, so this is basically a SIMD dot-product between x[in_pos..in_pos + GROUPSIZE] and whatever the weight becomes. My limited SIMD experience is based on writing CUDA kernels, so naturally I thought that this might be better if I rewrote this to use the fma instructions that have been around for the past few years.

I rewrote the interior slightly:

// Do K at a time
let zero = f32x8::splat(0.0);
let qzeros = &self.qzeros[oi / elems_per_i32];
let out_elem = oi % elems_per_i32;
let qweight = self.qweight[oi].chunks_exact(ipg);

let collect = self.scales[oi]
    .into_iter()
    .zip(qweight)
    .enumerate()
    .map(|(group, (scale, qweight))| {
        let qz = ((qzeros[group] >> (BITS * out_elem)) & mask) + 1;
        let scale_simd = f32x8::splat(scale);
        let qzero_simd = i32x8::splat(qz);
        let in_pos = group * GROUPSIZE;
        let xs = x[in_pos..in_pos + GROUPSIZE]
            .chunks_exact(8)
            .map(f32x8::from_slice);
        let q_op = qweight
            .iter()
            .map(|v| {
                //Extract v into 8 chunks
                let num_simd = i32x8::splat(*v);
                let qw: i32x8 = (num_simd >> shift_right) & mask_4bits;
                (qw - qzero_simd).cast::<f32>()
            });
        xs.zip(q_op).fold(zero, |acc, (x, y)| {
            x.mul_add(y, acc)
        }) * scale_simd
    })
    .fold(zero, |x, y| x + y);
*o = collect.reduce_sum();

However, this tanked performance. I'm working on getting a proper hyperfine benchmark up and running1, but it's safe to say that the performance is at least 50% worse. I thought issuing 1 instruction (fma) per element of the vector (well, N/8 since I have 8 lanes) would be faster than issuing 2 instructions (mul and add).

The reason, as best as I can understand it (not SIMD expert, yet!) is that in the above code, we're bottlenecked on instruction latency, not throughput. Apparently the computation

result = weight[0] * x[0] + weight[1] * x[1] + ... + weight[n] * x[n]

can be run highly out-of-order because it doesn't matter which multiplication happens first, or technically in which order the results are added2. That's why I see the add in the profiling output, because that's the instruction that ends up being where the CPU is waiting for the result of all the individual multiplications to finish.

In contrast:

result = fma(x[2], weight[2], fma(x[1], weight[1], fma(x[0], weight[0], 0)))

is highly order dependent, so the execution is much slower (also, I think in this case fma latency is 4 cycles vs 3 cycles for add and mul, which makes things much slower). Compilers are pretty smart! And modern out-of-order execution is pretty impressive.
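
To make the latency argument concrete, here is a back-of-the-envelope model in Python. It is only a sketch: it counts cycles on the serial dependency chain and ignores throughput limits, loads, and everything else a real CPU does. The latencies are the rough figures quoted above (4 cycles for fma, 3 for add and mul), and the chain length n is a made-up value for illustration.

```python
def critical_path_cycles(n_terms, chain_latency, setup_latency=0):
    """Cycles spent on the serial dependency chain of an n-term accumulation."""
    return setup_latency + n_terms * chain_latency


# Assumed latencies, taken from the rough estimate in the text.
FMA_LATENCY = 4
ADD_LATENCY = 3
MUL_LATENCY = 3

n = 16  # hypothetical number of vector terms in one accumulation

# mul + add: the multiplies don't depend on the accumulator, so they can overlap.
# Only the adds are serialized, after the first multiply has finished.
mul_add = critical_path_cycles(n, ADD_LATENCY, setup_latency=MUL_LATENCY)

# fma: every fma needs the previous accumulator value, so all of them are serialized.
fma = critical_path_cycles(n, FMA_LATENCY)

print(mul_add, fma)  # prints "51 64"
```

In this toy model the fma chain is longer simply because each link costs 4 cycles instead of 3. The usual way around this is to keep several independent accumulators so that multiple fma chains are in flight at once, though I haven't measured that here.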

Here's the people who helped me figure this out:

Is vfmadd132pd slow on AMD Zen 3 architecture?

Latency and throughput

  1. The full program is hard to benchmark because a ton of the initial time is spent mmaping the huge weights into memory, so I think it depends a lot on what the rest of my machine is doing (I think I have a fairly slow SSD at the moment).

  2. I'm not entirely sure about this, actually, since I think Rust doesn't enable -ffast-math in release builds (I'm not sure how to, at the moment), so I think this operation still ends up being non-associative.

https://rachitsingh.com/notes/fma-dependency/
Peter Reilly

I recently heard that one of my high school mentors, Dr. Peter Reilly, passed away last week at the age of 64.

He was a supremely patient, kind, and energizing mentor. I think one of the things that strikes me most is how at the time I didn't understand how rare that is, or how lucky we were to have his help.

When I met him, I had been looking for research projects to work on with my friend Vivek, and he was the first to really take us seriously. Many, many professors turned us away with responses like "come back in graduate school" or "I guess I can put you to use if you want to stack lead bricks". With 10 more years of experience, I can understand now how reasonable a response that is. But Dr. Reilly took the time to explain his work to us, and gave us a small project that we could tackle: translating an ion trap program from SL to Lua. Later this experience with Lua proved surprisingly useful when writing Lua Torch code (remember that?).

What I remember most vividly from this time is that occasionally I would get a confusing result, I would send Dr. Reilly an email, and then I'd poke my head in his office door and ask if he had a second, and pretty much always he would say "Sure!".

He taught me a lot of concrete skills like how to do real scientific work, communicate results, write a paper, and come up with ideas. I regret a lot that I didn't communicate to him enough what an impact he had on my life. With more perspective now I can understand why he and the other mentors in my life have taken a gamble on me: to pay it forward. I hope I can do as good a job.

Rest in peace, Dr. Reilly.

https://rachitsingh.com/peter-reilly/
Anti-shibboleths

In the few areas where I have a clue about what people are thinking about, there's words that make me wonder whether this person is either clueless or assuming I'm clueless.

One example is that in New York City, if you ask someone what their favorite type of food is, and they respond with "Asian", it sets off a few alarm bells. I'm not a very sophisticated restaurant-goer, and even I know that a description like that is so broad as to be almost meaningless.

A more prominent example today is using the word "AI". I think there are smart people who use the word precisely (e.g. as in "AI risk") but usually people using it at this moment in time mean something like "RLHF-tuned large language models" or even specifically ChatGPT. I see this most on Twitter, though I've encountered people in real life who will tell me about what "AI can do now". I'm still deciding if this is a useful heuristic, but I change how I view someone's idea based on whether they say LLM or AI.

I explicitly don't want to say that you need to understand a lot of math to use these tools in creative and useful ways. Specialization is important and not all researchers want to do product design - but it's very useful in any field to be precise with terms.

This post is now outdated. As with all things, technology adoption moves faster than we can imagine, and now it feels standard to see 'AI' embedded in products in ways I didn't expect. I think now (and probably before) using the term 'LLM' is more confusing than just saying it's 'AI'. Unsurprisingly, I didn't realize that end users don't care about how the text is generated. You can disregard this post entirely.
https://rachitsingh.com/notes/anti-shibboleths/
Dealing with Pandas's nullable float dtypes

Pandas added support for nullable float32 and float64 datatypes in the past few years (Float32 and Float64 respectively)1, but there are a lot of footguns, so it feels kind of bolted on. One example is that while arrays with this dtype can have both np.nan and pd.NA (aka None) as values, Series.isna() only catches the latter, not the former. This is a known bug, but there's no fix yet and it might take until Pandas 3.0 before one comes.

You can jump straight to the solution.

Here's an example provided by @cmillani:

import numpy as np
import pandas as pd
df = pd.DataFrame({'a': [5, 0], 'b': [np.nan, 12]})
df = df.astype('float64')
df['c'] = df['a'] * np.inf
df.isna()

       a      b      c
0  False   True  False
1  False  False   True

Note that casting to Float64 does change np.nans to pd.NA:

df = df.astype('Float64')
df

     a     b     c
0  5.0  <NA>   inf
1  0.0  12.0  <NA>

However, now df.isna() doesn't catch newly introduced np.nans:

df['c'] = df['a'] * np.inf
df.isna()

       a      b      c
0  False   True  False
1  False  False  False

Fix

In the meantime, you can fix it for yourself using the following monkey patch:

import numpy as np
from pandas.core.arrays.floating import FloatingArray
FloatingArray.oldisna = FloatingArray.isna
def newisna(self: FloatingArray) -> np.ndarray:
    return np.isnan(self._data) | self._mask.copy()
FloatingArray.isna = newisna
df.isna()

       a      b      c
0  False   True  False
1  False  False   True

You'll want to patch fillna, too, probably:

from pandas.core.arrays.masked import BaseMaskedArrayT, validate_fillna_kwargs, is_array_like, missing
FloatingArray.oldfillna = FloatingArray.fillna
def newfillna(
    self: BaseMaskedArrayT, value=None, method=None, limit=None
) -> BaseMaskedArrayT:
    value, method = validate_fillna_kwargs(value, method)

    mask = (self._mask | np.isnan(self._data)).copy()

    if is_array_like(value):
        if len(value) != len(self):
            raise ValueError(
                f"Length of 'value' does not match. Got ({len(value)}) "
                f" expected {len(self)}"
            )
        value = value[mask]

    if mask.any():
        if method is not None:
            func = missing.get_fill_func(method, ndim=self.ndim)
            npvalues = self._data.copy().T
            new_mask = mask.T
            func(npvalues, limit=limit, mask=new_mask)
            return type(self)(npvalues.T, new_mask.T)
        else:
            # fill with value
            new_values = self.copy()
            new_values[mask] = value
    else:
        new_values = self.copy()
    return new_values
FloatingArray.fillna = newfillna

There might be more efficient ways (though I think fillna seems to do an extra copy no matter what).

  1. I find this notation pretty confusing, and think the DE Shaw variant is much clearer (float32? and float64?, along with the same for every other dtype).

https://rachitsingh.com/notes/pandas-null/
Don't say it's easy

When teaching someone how to do something, don't say that "it's easy".

If it were easy for them, they wouldn't be asking you for help.

Even though you're trying to reassure them that it isn't actually that hard, you're starting from different places. For example, take calculus. Once you've learned some calculus, you have the mental framework, and you know what pitfalls to avoid. Someone encountering it for the first time is going to encounter problems you can't even imagine.

That's one of the things great teachers can do: mentally model what problems their students are going to run into, and nudge them to avoid those pitfalls.

https://rachitsingh.com/notes/dont-say-its-easy/
JetBlue WiFi

This one is for the poor soul who is attempting to push their code to GitHub right now via JetBlue's excellent ViaSat wifi, but is confused about why it's hanging:

It's because JetBlue blocks connections over port 22 (i.e. SSH).

It was previously possible to add the HTTPS URL of your repo as an alt remote, but this no longer works because GitHub has disabled password authentication for pushes over HTTPS. GitHub has a tutorial on using SSH over the HTTPS port, but this unfortunately doesn't work on JetBlue's wifi.

There's some relatively complicated things you can do by setting up a proxy of your own (on a free Google Cloud or Digital Ocean instance, for example) that will use the HTTPS CONNECT protocol to forward your (HTTPS) connection to the server you want to SSH to, two examples of which are here and here.

However, in practice, the easiest thing for me was to just connect to a VPN, which removed all restrictions.
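
If you want to confirm which ports are blocked before reaching for the VPN, a quick TCP check is enough. This is a minimal sketch using only the standard library; can_connect is a name I made up for illustration.

```python
import socket


def can_connect(host, port, timeout=3.0):
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and DNS failures.
        return False
```

For example, compare can_connect("github.com", 22) with can_connect("ssh.github.com", 443). Note that a successful TCP connection only shows the port is open; it doesn't guarantee that SSH traffic will actually pass, which seems to be the situation on JetBlue's wifi.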

https://rachitsingh.com/notes/jetblue-wifi/
WSL tips
Last updated: November 1st, 2023

Windows Subsystem for Linux (WSL) is my main programming environment, after giving up Windows in 2015. It is surprisingly good.

Shrink your virtual HD

The virtual hard drive in WSL seems to grow without bound (it grows as needed, but won't release space even if you delete data). After shutting down your WSL instances (i.e. wsl --shutdown), run this in PowerShell (as an administrator):

Optimize-VHD -Path $VHD_PATH -Mode full

where the $VHD_PATH is usually something like C:\Users\$USERNAME\AppData\Local\Packages\$DISTRO\LocalState\ext4.vhdx

There's an alternative method using diskpart described here which I've successfully used before.

Separately, there's a convenient package called wslcompact that will run PowerShell scripts for you to automate this process. I've only used it to show info, but most reports seem to indicate it works well.

Use WSL VSCode plugin

I was a long-time holdout on Visual Studio Code (VSCode), but I've completely switched from vim except for on-server editing. The things that made me switch were (a) Pylance, (b) rust-analyzer, and (c) Copilot. rust-analyzer is available inside vim as well, but the integration is less convenient, at least for me (I like being able to hover over variables and see the inferred type, it is extremely helpful).

You can use the WSL extension to open a folder for editing.

If you have VSCode installed, you can use code . from inside WSL to open a folder inside VSCode. This will also reset rust-analyzer if it ends up in a bad state.

Install CUDA

See this post for advice on how to install CUDA.

https://rachitsingh.com/notes/wsl-tips/
Write it up as you go

This is probably an idea familiar to most researchers (and really anyone who produces written output for their job), but writing it up as you go has huge advantages over writing it up at the end. This isn't just a time-management tactic, but a thinking tool.

I learned this from feedback during my first job, but the short summary is: if you write down your results as you go, you don't yet know what's relevant or important. That leaves organic room for growth as you revisit the ideas later.

For example, let's say I'm investigating the correlation between two features A and B in my dataset. I realize that the relationship changes depending on whether C is True or False. I might choose to look at the C=True case for a while, then revisit C=False later. If I don't write down this decision, I can forget to look at the C=False case until I'm writing the writeup... at which point I have to go back and do more research.

All I have left to do is write it up

The problem is that writing forwards for a reader forces you to reexamine all the decisions you made. The reader probably will. You have little control over how the reader thinks; this is by design, since the reader is usually error-checking and validating that what you're saying is true.

So, write it up as you go. I don't mean produce a polished draft of the introduction first, then do research. I mean keep notes about all your decisions and what you learn. In the above case, you might do something as terse as this:

corr(A, B):
    - ran in "exp 20230511.ipynb"
    - depends on C
        - checking C=True case in "exp 20230511 - C_true.ipynb"

This isn't meant for anyone to read but yourself. But write it up as you go.

https://rachitsingh.com/notes/writing-it-up-as-you-go/
Date your notebooks

Checked in code and commits are highly static artifacts intended for reliability. Engineers are told to make small commits so that they are easy to reason about (e.g., tests should pass before and after the commit).

Notebooks, on the other hand, are intended to be inherently flexible. They are a space for thinking and exploring. I probably share 5% of the notebooks I make (and those might need a different scheme than described here). As a result, I often delete cells to make space for new ideas.

I don't think you need a hierarchical notebook organization scheme, or a formal check-in process. Just making a new notebook every day or so, and backing up often, is enough for most people to regain their bearings.

But making a new one every day, with the date at the beginning (for easy sorting) is a simple tool to avoid getting overwhelmed when you have hundreds (or thousands) of files in your "notebooks" folder. It's a simple idea, but often overlooked.

I prefer the YYYYMMDD - {project name} - {comments}.ipynb format.
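
If you'd rather not type the date by hand, a tiny helper can build the name. This is just a sketch of the format above; notebook_name and its arguments are hypothetical.

```python
import datetime


def notebook_name(project, comments, day=None):
    """Build a 'YYYYMMDD - project - comments.ipynb' filename."""
    day = day or datetime.date.today()
    # date objects accept strftime codes inside an f-string format spec.
    return f"{day:%Y%m%d} - {project} - {comments}.ipynb"


print(notebook_name("amplify", "corr A B", datetime.date(2023, 5, 11)))
# prints "20230511 - amplify - corr A B.ipynb"
```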

https://rachitsingh.com/notes/date-your-notebooks/
RTX

I completely failed to achieve my goal of writing something non-technical at least once a month, so here we're trying not to fail with abandon. As a sad result of my WSL2 setup dying, I've had to reconfigure my desktop Linux experience from scratch again.

As a result I've run into Jeff Dickey's rtx, a software version manager written in Rust. Anyone who manages multiple projects with different sets of dependencies1 knows the pain of setting up Python environments. So far I've been using pyenv for that, which is a pretty good default.

rtx, like asdf before it, generalizes version management for software besides just Python versions, which is important when you're scared of npm (like me). Most importantly, it improves on asdf by speeding everything up and removing shims. The entire configuration is downloading rtx, and adding the following to your .zshrc:

eval "$(rtx activate zsh)"

Here's the result:

~ ❯ rtx list
⏵  python 3.10.10            (set by ~/.tool-versions)
⏵  rust 1.67.1               (set by ~/.tool-versions)
~ ❯ cd ~/proj/amplify
~/proj/amplify ❯ rtx list
⏵  python 3.10.10            (set by ~/proj/amplify/.rtx.toml)
⏵  rust 1.67.1               (set by ~/.tool-versions)
~/proj/amplify ❯ which python
/home/singhrac/proj/amplify/.venv/bin/python

Notice the virtualenv in there? That's from moving to a directory with a .rtx.toml in it:

[tools]
python = {version = '3.10', virtualenv='.venv'}

One of the most reassuring parts of the experience is that Jeff is actively working on fixing bugs and adding features pretty much all the time.

  1. Ever try testing your PyTorch code on multiple numpy/torch/Python versions? Good luck if you're not using Nox for automated testing, but even with that set up, you'll eventually need to debug some version conflict.

https://rachitsingh.com/notes/rtx/
Blogging

I spent an unfortunate amount of time writing and re-writing my recent blog post about Jupyter notebooks, mostly double checking that I liked the overall style and tone. After I put it up on the shiny new website, I hesitated and finally submitted it to Hacker News, the message board most likely to care.

I put my computer away and focused on brunch with my friends, and a few hours later logged on to see what had happened: nothing at all. It received no upvotes at all, and as far as I could see never reached anywhere near the front page.

I had two reactions. First, I grimaced that no one cared about my writing. Then I felt free: if no one was reading my writing then I should feel free to hit publish more often.

After a few hours I’m realizing that that’s not true - I cared. The whole exercise is for me. Here’s to an unburdened and optimistic 2023.

https://rachitsingh.com/notes/blogging/
Collaborating effectively with Jupyter notebooks

I spend roughly half of my programming time working in Jupyter notebooks1. Some is exploring data and building models, and some is experimenting when writing new code. I like notebooks - they make exploration much easier. For data science work, visualization with notebooks is just easier than what came before.

However, opinions of notebook-style programming vary wildly, from love to mild dislike to hate. There are many tropes about notebook users, but the most common one is that a data scientist opens up a notebook, writes spaghetti Python code that kind of produces a model, and then hands it over to a "real engineer" to turn into something that can be run in production. The most common reaction is that the data scientist loves Jupyter, and the engineer hates it.

This dynamic has been picked over many times, by a wide variety of data science, data engineering, and machine learning teams 2. Many startups have been built around trying to fix this problem.

Instead of writing an ethnography of ML engineering3, here's a list of principles I follow for creating Jupyter notebooks for handoff, as someone who probably straddles the line between data scientist and developer.

1. Kill the kernel and rerun all cells

If you're sending a notebook to someone else, the bare minimum is that you should restart your kernel and verify that the results look right. This almost goes without saying.

The bare minimum is restarting your kernel.

Often you'll realize that you've deleted a cell that was important, or run cells out of order. This is on you, the notebook author, to fix. It's very common to avoid this step because the notebook takes a long time to run. Further down I include some helpful advice on how to conquer this fear.

2. Don't mutate. Don't mutate. Don't mutate. (across cell boundaries)

Don't mutate variables defined in an earlier cell if you can avoid it.

One common anti-pattern for notebook users is to do something like this:

# cell 1
df = pd.read_csv("<fn>").set_index('timestamp')

...

# cell 25
df = df.reset_index().join(...).set_index('date')

...

# cell 146
df.plot()

Remember that most readers of notebooks skim out of sheer necessity: it is tiring reading someone else's code. Keeping a mental model of the variable df and its index is extremely difficult.

On the other hand, assuming variables are totally immutable is a great way to use up all your memory, since most variables you define will be in the global scope and not garbage collected.

df_a = pd.read_csv("<fn>")
df_b = df_a.groupby(['a']).response.mean()
...

My proposed solution:

  • Reuse variable names inside a cell so that intermediate state can be garbage collected
  • Don't mutate the variable outside of the cell that it's defined in
  • Ideally "export" one variable per cell, and make it obvious (i.e. the last line of a cell is something like exported_var = ...).

For example (with a contrived example):

import seaborn as sns

iris = sns.load_dataset('iris')
mean = iris.groupby("species").petal_width.mean()
mean += iris.groupby("species").sepal_width.mean()
std = iris.groupby("species").std()[['sepal_width', 'petal_width']].mean(axis=1)
species_z = mean / std

3. Use Markdown to separate regions

Make liberal use of Markdown headers to separate and organize your work.

One of the major advantages of notebooks over regular Python code is that you can include Markdown to write explanations of the content. I'd highly recommend writing headers (e.g. ### Loading data) to make it easier for your readers. You can also fold sections by clicking on the caret on the left side of the header.

This becomes much easier once you learn the muscle memory for a few shortcuts:

  • Esc moves you from editing to command mode
  • j/k or up/down to select a different cell in command mode
  • b to insert a cell below the currently selected one
  • a to insert a cell above the currently selected one
  • m to convert a code cell into a Markdown cell

Pretty quickly the muscle memory of Esc -> b -> m -> Enter becomes automatic.

4. Cache expensive computation to disk (automatically)

Make it easier to kill your kernel by caching large datasets to disk.

A common anti-pattern for notebook users is to do some slow and memory-intensive computation, and then hold the resulting data in memory and be scared of killing the kernel and losing that data.

The solution: write it to disk!

However, most notebook users avoid this not because writing to disk is infeasible, but because the tools they have for doing this are poor.

Here's a common example of loading a lot of data from disk to memory, joining it against another dataframe, then selecting a subset of it:

df_list = []
for i in range(100):
    df = pd.read_parquet(f"filename_{i}.parquet")
    df = df.join(my_other_dataset.set_index('ts'), on='ts')
    df = df[df.ts > pd.Timestamp('20200101')]
    df_list.append(df)

df = pd.concat(df_list)

Here are a few issues with this code:

  1. If it's slow, you won't want to run it again, so you'll avoid killing your kernel.
  2. df_list is still hanging around, holding pointers to each individual dataframe, so those can't be garbage collected

Instead, use a decorator to cache this computation to disk and compute it just once:

@cache_to_disk
def compute_my_dataframe():
    df_list = []
    for i in range(100):
        df = pd.read_parquet(f"filename_{i}.parquet")
        df = df.join(my_other_dataset.set_index('ts'), on='ts')
        df = df[df.ts > pd.Timestamp('20200101')]
        df_list.append(df)
    return pd.concat(df_list)

df = compute_my_dataframe()

One issue with this approach is that if you're not careful, you can leak other data into this function implicitly; above we're leaking the variable my_other_dataset. If you change how that variable is defined, this can lead to unreproducible results. In practice, though, I find this simpler to reason about than most people think. You can also delete your cache before rerunning it for sharing.

Using a decorator means that you don't have to manage the disk writing and reading yourself, which means the barrier to doing this drops significantly. There are many useful decorators, but here's a simple one I wrote just for you:

from functools import wraps
import pathlib
import pickle
import pandas as pd

def cache_to_disk(func):
    """
    A decorator for functions with no arguments.
    """
    cache_path = pathlib.Path(".")
    func_cache_path = (
        cache_path /
        func.__module__.replace(".", "/") /
        func.__name__
    ).with_suffix('.cache')
    if not func_cache_path.parent.exists():
        func_cache_path.parent.mkdir(parents=True)

    @wraps(func)
    def wrapped():
        if func_cache_path.exists():
            try:
                ret = pd.read_parquet(func_cache_path)
            except Exception:  # not a parquet file, so fall back to pickle
                with open(func_cache_path, 'rb') as fh:
                    ret = pickle.load(fh)
            return ret
        else:
            ret = func()
            if isinstance(ret, pd.DataFrame):
                ret.to_parquet(func_cache_path, compression='zstd')
            else:
                with open(func_cache_path, 'wb') as fh:
                    pickle.dump(ret, fh)
            return ret
    return wrapped
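
One way to reduce the implicit-leak problem mentioned above is to pass dependencies in as arguments and make them part of the cache key. Here's a rough sketch of that idea using only the standard library. cache_to_disk_with_args is a hypothetical variant, not a drop-in replacement for the decorator above: it always pickles the result, and it assumes the arguments pickle to the same bytes every time, which holds for simple values but gets expensive for big dataframes.

```python
import hashlib
import pathlib
import pickle
from functools import wraps


def cache_to_disk_with_args(cache_dir="."):
    """Hypothetical variant of cache_to_disk whose cache key includes the arguments."""
    cache_root = pathlib.Path(cache_dir)

    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            # Hash the pickled arguments so different inputs get different cache files.
            payload = pickle.dumps((args, sorted(kwargs.items())))
            key = hashlib.sha256(payload).hexdigest()[:16]
            path = cache_root / f"{func.__name__}-{key}.cache"
            if path.exists():
                with open(path, "rb") as fh:
                    return pickle.load(fh)
            ret = func(*args, **kwargs)
            cache_root.mkdir(parents=True, exist_ok=True)
            with open(path, "wb") as fh:
                pickle.dump(ret, fh)
            return ret
        return wrapped
    return decorator
```

You would use it as @cache_to_disk_with_args(cache_dir="cache") on a function that takes, say, the cutoff date as an argument, so that changing the cutoff recomputes instead of silently returning the stale result.
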

5. Use assertions

Assert anything you confirm "by inspection" while working with the data:

assert len(df.species.unique()) == len(ALL_SPECIES)
assert df.index.name == 'blah'
assert iris.dtypes['sepal_length'] == np.float64

They help the reader build their mental model of what is guaranteed and checked, and what is not.

6. Don't import code from other notebooks (or copy and paste)

Instead of copying and pasting code (or worse), set up a local Python module to import. It's easier than you think.

Importing code from other notebooks is extremely weird. Since most notebooks are written like one long main script, this is like importing from one script to another.

Copying and pasting a huge block of useful snippets is common (e.g. methods to load data, clean it, plotting, etc). The issue is that it's very easy to copy and paste irrelevant code around, and your reader will have trouble understanding what's used and what's not. Also this can take time and adds to visual clutter.

One alternative is that if you find some code becoming useful, you can add it to your own personal little library or package, and import that code everywhere.

I find that most people find this pretty intimidating or a waste of time, because they assume it requires checking code in or deploying it. That tends to be what developers recommend as well. However, you can set up your repo in "develop" mode and any changes you make to the file will be immediately importable.

Here's one way to set this up when not working in a cloud environment:

  1. Create an empty directory in an easy-to-edit place (e.g. /home/rachit/utils/)
  2. Inside that folder, create a minimal Python package as described here: minimal Python package setup
  3. Making sure that your python is whatever is used in your Jupyter kernel, run python setup.py develop in that directory. You might also be able to use pip install -e .; I haven't tested it.

Now, whenever you make any changes to that module, you can restart your kernel and get the updated version of that code. Using the example module described in the article:

from funniest import joke

7. Avoid using generic variable names

Generic variable names confuse readers and can also lead to subtle bugs because of the default-global context.

The usual culprit is df, which I've used above. It's convenient because it's quick to type, but Jupyter supports autocomplete, so you can just hit TAB to complete a longer variable name.

The issue with using generic variable names is that (1) it makes it harder to read through and understand the code, and (2) common variable names are often overwritten, which violates the "Don't mutate" rule but also confuses most people.

Here's an example:

# cell 1
df = pd.read_csv("...")

# cell 25
dfs = []
for i in range(25):
    df = pd.read_parquet(f"filename_{i}.parquet")
    dfs.append(df)
other_dataset = pd.concat(dfs)

# cell 45
df.plot()

Here, the last df is actually the final instance of df in the loop on cell 25, i.e. whatever is inside filename_24.parquet. Python for loops don't create a new scope, so the loop variable silently rebinds the notebook-level df, and you might not even notice. It's better just to name the first variable something more descriptive.
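
You can see the rebinding without any data at all. In this small sketch, plain strings stand in for the dataframes, and raw_events and chunk are made-up names for illustration.

```python
# Strings stand in for dataframes to keep the sketch dependency-free.
df = "contents of the original CSV"

chunks = []
for i in range(25):
    df = f"contents of filename_{i}.parquet"  # rebinds the notebook-level df
    chunks.append(df)

print(df)  # prints "contents of filename_24.parquet"

# With a descriptive name for the first variable and a distinct loop name,
# the original survives the loop.
raw_events = "contents of the original CSV"
for i in range(25):
    chunk = f"contents of filename_{i}.parquet"

print(raw_events)  # prints "contents of the original CSV"
```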

8. Don't use notebooks when you mean to use an IDE

This isn't advice on how to collaborate with notebooks.

Real IDEs (tm) have many useful features that notebooks (at least currently) don't have:

  • High quality type inference / Intellisense / autocomplete 4
  • Easier to productionize code by splitting code into files and modules and tests
  • Access to (many) plugins that haven't yet made it over to JupyterLab, etc.

My general rule of thumb: if you're writing production code, it's easier to do it in an IDE or a text editor like Vim. If you're creating a reproducible chart, or training a small model for research use, it's better to do it in a notebook, if you can.

Summary

Have empathy for the person who is going to read your notebook; reading someone else's code (especially in a critical setting) is exhausting and hard work.

Thanks

Thanks to Stefan Gramatovici and Prastuti Singh for reading earlier drafts of this post.

  1. The remaining half is spent between VSCode and Vim; roughly 98% of the code I write for work is in Python.

  2. If you haven't already seen / read Joel Grus's talk from JupyterCon 2018, it's pretty funny. I would recommend.

  3. Some of the most common complaints: poor (unreadable) coding style due to the code being written for a single author, reliance on data without provenance (i.e. I downloaded a file to my local home directory and am now loading it), complex . If I get one thing across to data scientists who collaborate with engineers: the "exploratory" programming style is not production code, because production code needs to be readable by people with a range of context on what the code does.

  4. I'm a big fan of Pylance inside VSCode; I realize it's a bit concerning that this is closed source, but at a high level I'm ok with it. I use many closed source products from Microsoft daily.

https://rachitsingh.com/collaborating-jupyter/
Setting up a new server

These are things I always forget but need to be really comfortable working in a remote server. I try to be distro-agnostic when I can.

  1. Set up history options in .bashrc:
    HISTCONTROL=ignoreboth:erasedups
    HISTSIZE=100000
    HISTFILESIZE=100000
    shopt -s histappend
    shopt -s checkwinsize
    PROMPT_COMMAND="history -a; history -c; history -r; $PROMPT_COMMAND"
    
    case "$TERM" in
        xterm-color|*-256color) color_prompt=yes;;
    esac
    
    alias ls='ls --color=auto'
    alias ll='ls -alF'
    alias la='ls -A'
    alias l='ls -CF'
    
    alias ag=rg
    
  2. Install cargo and all the Rust tools I love:
    curl https://sh.rustup.rs -sSf | sh
    cargo install ripgrep fd-find procs du-dust
    
  3. Install build-essential or the equivalent:
    sudo apt-get install build-essential git # for Ubuntu
    
    sudo yum groupinstall "Development Tools" # for AWS Linux
    sudo yum install git
    
  4. Setup pyenv:
    git clone https://github.com/pyenv/pyenv.git ~/.pyenv
    cd ~/.pyenv && src/configure && make -C src
    echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
    echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
    echo 'eval "$(pyenv init -)"' >> ~/.bashrc
    exec "$SHELL"
    
    # install build requirements
    sudo apt update; sudo apt install make build-essential libssl-dev zlib1g-dev \
    libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
    libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev
    
    # install an example python version
    pyenv install 3.10.6
    
    git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
    echo 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
    exec "$SHELL"
    
    pyenv virtualenv 3.10.6 prod
    
https://rachitsingh.com/notes/server-setup/
Installing docker & docker compose on Amazon ARM Linux

If you're like me and have spun up a t4g.small during AWS's free trial period, you might want to run docker and docker compose. However, documentation on this is pretty sparse.

docker-compose vs docker compose

Apparently, docker compose is the new hotness and should be used going forward. I'm sure there are many changes under the hood but the headliner is that it's written in Go and docker-compose was written in Python (apparently).

docker

Based on this gist:

sudo yum install docker
sudo service docker start
sudo usermod -a -G docker ec2-user

Auto start:

sudo systemctl enable docker.service

docker compose

This is the one you want, not docker-compose.

From here

sudo mkdir -p /usr/local/lib/docker/cli-plugins/
sudo curl -SL https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/lib/docker/cli-plugins/docker-compose
sudo chmod +x /usr/local/lib/docker/cli-plugins/docker-compose

Questions & comments: feel free to send them to fullname [at] outlook.com.

https://rachitsingh.com/notes/aws-linux-docker-compose/
The Langevin Equation

This post covers the Langevin equation, a stochastic differential equation that models the dynamics of particles in Brownian motion1. This covers the ideas used in this reference due to Lennart Sjögren.

Langevin Equation

In 1905 Einstein published a paper that expressed a macroscopic quantity $D$, the diffusion constant, in terms of microscopic quantities:

$$D = \frac{k_BT}{6\pi\eta a}$$

where $\eta$ is the viscosity of the liquid and $a$ is the radius of the particle. In 1908 Langevin derived the same result using Stokes' law to note that the drag coefficient of a particle is $\gamma = 6\pi \eta a$, so that we can write the equation of motion as:

$$m\frac{d^2x}{dt^2} = -\gamma \frac{dx}{dt} + \xi$$

for some noisy force $\xi$. The argument, which is given in the linked paper, relies on the fact that we can multiply by $x$, and take the average over all the particles. This gives us a term $\overline{\xi x}$, which Langevin claimed is 0. This 'averaging' lets him work in variance-space, which leads to a result about the second moment of the group of particles: $\overline{x^2} - \overline{x_0^2} = 2Dt$, using Einstein's $D$, which is the same as the definition of $D$. The averaging procedure is a bit sketchy, and some people have noted the handwave. Our goal is to put this on a surer footing.

Stochastic Differential Equations

An ordinary differential equation might look like $y'(t) = f(y(t))$, or $dy(t) = f(y(t)) dt$ in differential form. If we want to model a process with noise, we might add a Brownian motion increment $\sigma dB_t$ to get:

$$ \begin{aligned} dy_t &= f(y_t)dt + \sigma dB_t \\ y_t &= y_0 + \int_0^t f(y_s) ds + \sigma B_t \end{aligned} $$

This is a regular integral because $\sigma$ doesn't depend on time, but the end result is a stochastic process (i.e. a collection of random variables indexed here by time). We can of course generalize this as much as we'd like, including letting $\sigma$ be a function of $y$ or $t$, which would necessitate a stochastic integral, but luckily we don't need to just yet (that'll come later).

In our case, we can write the Langevin equation as:

$$ \begin{aligned} \frac{dx_t}{dt} &= v_t \\ \frac{dv_t}{dt} &= -\frac{\gamma}{m}v_t + \frac{1}{m}\xi_t \\ \implies dv_t &= -\frac{\gamma}{m}v_t dt + \frac{1}{m}\xi_t dt \end{aligned} $$

In application, we can note the following facts: $\mathbb{E}[\xi(t)] = 0$, and $\mathbb{E}[\xi(t_1)\xi(t_2)] = g\delta(t_1 - t_2)$. The first statement says that the force has 0 mean. The second says that it is totally uncorrelated with some variance $g$ at any given point. This assumption is realistic because the particles considered are being hit by other particles many billions of times a second. Finally, we can make the (fairly strong) assumption that the force applied at each time comes from a Gaussian distribution with the moments implied by the previous facts2.

Then $\xi(t)$ satisfies the definition of a Gaussian white noise, so we can confidently treat its integral as a Wiener process (or Brownian motion process), scaled so that its variance grows as $gt$. We can also get this from Donsker's theorem I think, though I'm not sure about the details. Letting $U_t = \int_0^t \xi_s ds$, so that $\mathbb{E}[U_t^2] = gt$, we finally have the SDE for an Ornstein-Uhlenbeck process3:

$$dv_t = -\frac{\gamma}{m} v_t dt + \frac{1}{m}dU_t$$

which we can integrate by applying Ito's formula in the following way:

$$ \begin{aligned} de^{\gamma t/m}v_t &= \left(\frac{\gamma}{m}e^{\gamma t/m}v_t dt + e^{\gamma t/m}dv_t\right) \\ &= \frac{1}{m}e^{\gamma t/m} dU_t \\ \implies e^{\gamma t/m}v_t &= v_0 + \frac{1}{m}\int_0^t e^{\gamma s/m} dU_s \\ v_t &= v_0e^{-\gamma t/m} + \frac{1}{m}\int_0^t e^{-\gamma (t - s)/m} dU_s \\ \end{aligned} $$

This is in some sense the 'solution' to the Langevin equation, but our goal is to rederive the results in the Langevin paper almost de novo. Clearly from this result the expectation is just $v_0e^{-\gamma t/m}$ (which decays to 0), since $U_s$ the Wiener process has mean 0. However, for the second moment (from which we can compute the variance) we need to apply Ito's isometry:

$$ \begin{aligned} \mathbb{E}[v_t^2] &= e^{-2\gamma t/m}v_0^2 + \mathbb{E}\left[\left(\frac{1}{m}\int_0^t e^{-(t - s)\gamma/m} dU_s\right)^2\right] \\ &= e^{-2\gamma t/m}v_0^2 + \frac{g}{m^2}\int_0^t e^{-2(t - s)\gamma/m} ds \\ &= e^{-2\gamma t/m}v_0^2 + \frac{g}{2\gamma m}\left[e^{-2(t - s)\gamma/m}\right]_{s=0}^{s=t} \\ &= e^{-2\gamma t/m}v_0^2 + \frac{g}{2\gamma m}\left(1 - e^{-2\gamma t/m}\right) \\ \end{aligned} $$

Now, as $t \to \infty$, we'd expect $\mathbb{E}[v_t^2] = k_BT/m$ (from equipartition), and the limit of the above result is $g/2\gamma m$, so we can conclude that:

$$g = 2\gamma k_B T$$

Finally, we can use this to derive the dynamics of the particle itself (as a stochastic process, of course):

$$ \begin{aligned} x_t &= x_0 + \int_0^t v_s ds \\ &= x_0 + \int_0^t \left(v_0e^{-\gamma s/m} + \frac{1}{m}\int_0^s e^{-\gamma (s - u)/m} dU_u\right) ds \\ &= x_0 + \frac{m}{\gamma}v_0\left[1 - e^{-\gamma t/m}\right] + \frac{1}{m} \int_0^t \left(\int_u^t e^{-\gamma(s - u) / m} ds\right) dU_u \\ &= x_0 + \frac{m}{\gamma}v_0\left[1 - e^{-\gamma t/m}\right] + \frac{1}{\gamma} \int_0^t \left[1 - e^{-\gamma(t - u)/m}\right] dU_u \end{aligned} $$

and again apply Ito's isometry to get the second moment (here, I drop the leading terms, which stay bounded as $t \to \infty$ and so are negligible next to the part that grows linearly in $t$):

$$ \begin{aligned} \mathbb{E}[(x_t - x_0)^2] &= \left(\text{bounded terms} \ldots \right) + \frac{1}{\gamma^2}\mathbb{E}\left[\left(\int_0^t \left[1 - e^{-(t - u)\gamma/m}\right] dU_u\right)^2\right] \\ &= \left(\text{bounded terms} \ldots \right) + \frac{g}{\gamma^2} \int_0^t\left[1 - e^{-\gamma(t - s)/m}\right]^2 ds \\ &= \left(\text{bounded terms} \ldots \right) + \frac{g}{\gamma^2}\left[t - \frac{2m}{\gamma}\left(1 - e^{-\gamma t/m}\right) + \frac{m}{2\gamma}\left(1 - e^{-2\gamma t/m}\right)\right] \end{aligned} $$

So, in the long run, the variance grows as $\mathbb{E}[(x_t - x_0)^2] \approx \frac{g}{\gamma^2} t = \frac{2k_BT}{\gamma} t$, using $g = 2\gamma k_BT$. Comparing this to the diffusion equation which says that this should grow as $2Dt$, we get that

$$D = \frac{k_BT}{\gamma}$$

which is Einstein's result.
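As a sanity check on the algebra, here is a minimal Euler-Maruyama sketch in Python (standard library only). The units $m = \gamma = k_BT = 1$, the step size, and the function names are my own choices for illustration, not something taken from Langevin's paper or the lecture notes. It simulates the velocity SDE with noise strength $g = 2\gamma k_BT$, starting from $x_0 = v_0 = 0$, and returns sample averages that can be compared against equipartition and against the closed form of $\frac{g}{\gamma^2}\int_0^t [1 - e^{-\gamma(t-s)/m}]^2 ds$.

```python
import math
import random


def simulate_langevin(n_particles=1000, t_end=5.0, dt=0.01,
                      m=1.0, gamma=1.0, kT=1.0, seed=0):
    """Euler-Maruyama for dv = -(gamma/m) v dt + (1/m) dU, dx = v dt.

    dU is a Gaussian increment with variance g*dt, where g = 2*gamma*kT.
    Returns the sample averages of v_t^2 and x_t^2 at t = t_end.
    """
    rng = random.Random(seed)
    g = 2.0 * gamma * kT
    noise_scale = math.sqrt(g * dt) / m
    n_steps = int(round(t_end / dt))
    sum_v2 = 0.0
    sum_x2 = 0.0
    for _ in range(n_particles):
        x, v = 0.0, 0.0  # every particle starts at rest at the origin
        for _ in range(n_steps):
            x += v * dt
            v += -(gamma / m) * v * dt + noise_scale * rng.gauss(0.0, 1.0)
        sum_v2 += v * v
        sum_x2 += x * x
    return sum_v2 / n_particles, sum_x2 / n_particles


def exact_msd(t, m=1.0, gamma=1.0, kT=1.0):
    """Closed form of (g/gamma^2) * int_0^t [1 - exp(-gamma (t-s)/m)]^2 ds (v_0 = 0)."""
    g = 2.0 * gamma * kT
    tau = m / gamma
    return (g / gamma ** 2) * (
        t
        - 2.0 * tau * (1.0 - math.exp(-t / tau))
        + 0.5 * tau * (1.0 - math.exp(-2.0 * t / tau))
    )
```

Calling `simulate_langevin()` and comparing the two returned averages with `kT / m` and `exact_msd(5.0)` should agree up to Monte Carlo and discretization error; for large `t` the closed form is dominated by the $2Dt$ term with $D = k_BT/\gamma$.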

References

Le Gall, Jean-François. Brownian motion, martingales, and stochastic calculus. Vol. 274. Heidelberg: Springer, 2016.

Sjögren, Lennart. Stochastic Processes lecture notes ch. 6: Brownian Motion: Langevin Equation. Retrieved from here

Next: The Feynman-Kac Formula which definitely has applications in statistics.

  1. This is pretty related to stochastic gradient Langevin dynamics (see here) I think. I don't think I know that paper well enough or the surrounding background in (Neal, 2010) to comment intelligently yet, but something I hope to get to soon. My understanding is that Langevin dynamics are essentially the above system, but with a driving force (maybe the gradient of the loss?).

  2. One attempt to justify that the distribution of $\xi(t)$ at each time must be Gaussian is the following: we let $dt$ be large enough that hundreds of collisions still happen. No matter what distribution the individual collisions come from, since the variables are i.i.d. we can apply the regular central limit theorem (CLT) to show that the overall force converges in distribution to a Gaussian.

  3. Actually, Ornstein developed these methods in order to formalize Langevin's arguments. I've linked a review paper from 1930, but the first version was published in 1917, I think.

https://rachitsingh.com/ornstein-uhlenbeck/
Persistence Length

In class we recently discussed the simplified elastic rod model for polymers, which assumes that polymers can be modeled as an inextensible rod, i.e. that the length of the rod doesn't change, and that the twist of the polymer is ignorable (possibly because the polymer is joined by single bonds). I want to give a short proof of the statement that the tangent-tangent correlation function, i.e. the autocorrelation of the polymer at some distance $x$, is exponentially decaying 1.

Model setup

We let the polymer be of length $L$, with $s \in [0, L]$ denoting the position of interest along the polymer. At each position $s$, we have a unit tangent vector $\mathbf{t}(s)$. In the simplified elastic rod model, it's important to understand the forces at work: first, given a particular configuration of the polymer, there's an energy corresponding to it, from the potential energy of the bending of the polymer. We argue that this energy is quadratic in the amount of the bend since locally it's stiff, and stiff rods essentially follow Hooke's law. We can measure 'the amount of bend' as $d\mathbf{t}(s)/ds$, so in this model we have:

$$dE = \frac{1}{2}k_BT \left(A \left|\frac{d\mathbf{t}(s)}{ds}\right|^2\right) ds$$

here, $A$ is an introduced constant with units of length. For any configuration $\omega$ of the polymer we can integrate this to get the total potential energy $E(\omega)$.

Second, there's entropy, which comes from the temperature. For long ropes this is negligible, but for polymers a few nanometers wide in solution this is very significant. The model uses the standard method of applying the Boltzmann distribution - i.e. for each configuration $\omega$ with potential energy $E(\omega)$ the probability of being found in that configuration is:

$$p(\omega) \propto e^{-E(\omega)/k_BT}$$

In fact, that's all we really need to set this up.

Argument

The most important thing here is to draw a very clear picture. We're going to consider 3 positions $A, B, C$ along the polymer, in that order, and consider the tangent-tangent correlation between $A$ and $C$, which is $\mathbf{t}(A) \cdot \mathbf{t}(C)$. The key step here is to set up axes where $\mathbf{t}(B)$ is on the $z$-axis, and $\mathbf{t}(A)$ lies in the $x$-$z$ plane (i.e. a spherical coordinate system with $\mathbf{t}(B)$ as the pole):

We let $\theta_A, \theta_C$ denote the angle between $\mathbf{t}(B)$ and $\mathbf{t}(A), \mathbf{t}(C)$ respectively. $\phi_C$ is the azimuthal angle of $\mathbf{t}(C)$, i.e. the angle between the $x$-axis and the projection of $\mathbf{t}(C)$ onto the $x$-$y$ plane.

Then, from the regular inner product in this coordinate system we can see that:

$$\mathbf{t}(A) \cdot \mathbf{t}(C) = \underbrace{\cos \theta_A \cos \theta_C}_{z-\text{axis}} + \underbrace{\sin \theta_A \sin \theta_C \cos \phi_C}_{x-\text{axis}} + \underbrace{0}_{y-\text{axis}}$$

We want to find the expected value of this quantity (which is a random variable in $A, C$, since we fix $B$) under the Boltzmann distribution:

$$\langle\mathbf{t}(A) \cdot \mathbf{t}(C)\rangle = \int_{0}^{2\pi} \int_0^\pi \int_0^{\pi} \left(\cos \theta_A \cos \theta_C + \sin \theta_A \sin \theta_C \cos \phi_C\right) e^{-E(\theta_A, \theta_C)/k_BT} d\theta_A d\theta_C d\phi_C$$

The argument Nelson makes is that since the energy isn't dependent on $\phi_C$ (since the only thing that matters is the bend off $\mathbf{t}(B)$), we can factor the second term in the integral as:

$$\int_{0}^{\pi} \int_0^\pi \sin \theta_A \sin \theta_C \cdot e^{-E(\theta_A, \theta_C)/k_BT} \left(\int_0^{2\pi} \cos \phi_C d\phi_C\right) d\theta_A d\theta_C $$

But the integral on the inside is 0, so that term vanishes. Next, we should note that $E(\theta_A, \theta_C)$ can be decomposed as $E(\theta_A) + E(\theta_C)$. This comes from the form of the integral as an integration with respect to $\operatorname{d}s$. Essentially:

$$\int_A^C dE = \int_A^B dE + \int_B^C dE$$

So we can write:

$$ \begin{aligned} \langle\mathbf{t}(A) \cdot \mathbf{t}(C)\rangle &= \int_0^\pi \int_0^\pi \cos \theta_A \cos \theta_C e^{-E(\theta_A, \theta_C)/k_BT} d\theta_C d\theta_A \\ &= \int_0^\pi \cos \theta_A e^{-E(\theta_A)/k_BT} d\theta_A \int_0^\pi \cos \theta_C e^{-E(\theta_C)/k_BT} d\theta_C \\ &= \langle\mathbf{t}(A) \cdot \mathbf{t}(B)\rangle \cdot \langle\mathbf{t}(B) \cdot \mathbf{t}(C)\rangle \end{aligned} $$

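From here, we can apply the Cauchy functional equation: writing $f(s)$ for the correlation between tangents a distance $s$ apart, the factorization says $f(s_1 + s_2) = f(s_1)f(s_2)$, and since the correlation is bounded the solution has to be an exponential, so the correlation is exponentially decaying. The computation of that scalar factor is well-handled both by Nelson and Physical Biology of the Cell, so I won't reproduce it here.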
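To see the factorization numerically, here is a small Monte Carlo sketch of a discretized chain in Python. It is not Nelson's setup: the Gaussian bend angle with standard deviation `sigma`, the segment count, and all function names are assumptions I picked for illustration. Each segment bends away from the previous tangent by an angle $\theta$ with a uniformly random azimuth $\phi$, and we estimate $\langle \mathbf{t}_0 \cdot \mathbf{t}_k \rangle$.

```python
import math
import random


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def normalize(a):
    n = math.sqrt(sum(c * c for c in a))
    return tuple(c / n for c in a)


def bend(t, theta, phi):
    """Return the unit vector at polar angle theta from t, with azimuth phi around t."""
    helper = (1.0, 0.0, 0.0) if abs(t[0]) < 0.9 else (0.0, 1.0, 0.0)
    e1 = normalize(cross(t, helper))  # unit vector perpendicular to t
    e2 = cross(t, e1)                 # completes the orthonormal frame
    return normalize(tuple(
        math.cos(theta) * t[i]
        + math.sin(theta) * (math.cos(phi) * e1[i] + math.sin(phi) * e2[i])
        for i in range(3)
    ))


def tangent_correlations(n_chains=20000, n_segments=3, sigma=0.5, seed=0):
    """Estimate <t_0 . t_k> for k = 0..n_segments over many random chains."""
    rng = random.Random(seed)
    sums = [0.0] * (n_segments + 1)
    for _ in range(n_chains):
        t0 = (0.0, 0.0, 1.0)
        t = t0
        sums[0] += 1.0
        for k in range(1, n_segments + 1):
            theta = rng.gauss(0.0, sigma)         # assumed bend-angle distribution
            phi = rng.uniform(0.0, 2.0 * math.pi)  # energy does not depend on azimuth
            t = bend(t, theta, phi)
            sums[k] += sum(a * b for a, b in zip(t0, t))
    return [s / n_chains for s in sums]
```

If the factorization holds, the estimate at separation $k$ should match the one-step estimate raised to the $k$-th power (up to sampling noise), which is exactly the discrete version of exponential decay.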

tldr

The factorization of correlation comes implicitly from:

  1. the fact that in 3 dimensions we only care about the deviation from the central axis, with $$\cos(\theta_A + \theta_C) = \cos(\theta_A)\cos(\theta_C) + (\text{stuff that dies}\ldots)$$, and
  2. the energy splits because of independence of the chain, and that means the probability distribution factors.

References

Nelson, P. (2008). Biological physics. New York: WH Freeman.

  1. I found two clear treatments: one in Nelson's Biophysics (9.1.3 Track 2), which has been on my shelf for 4 years but I still haven't read through the whole thing, and some lecture notes from EPFL. The latter is more rigorous (and Josh, another student, presented it in class), but I found it kind of annoying that there wasn't a simpler proof. Nelson has a good perspective but doesn't include any drawings, so I've covered it here.

https://rachitsingh.com/persistence-length/
A few favorite papers of 2017

This isn't an exhaustive list, and I will inevitably forget some papers. I'll keep updating as I remember, and will probably expand some of the background/contribution sections as I have time, so that they're more accessible.

Breaking the Softmax Bottleneck: A High-Rank RNN Language Model [link]

Background: Language models and NLP tasks almost always use a softmax to compute a distribution over the vocabulary, and usually this is computed as $\sigma(\mathbf{W}\mathbf{h})$, where $\mathbf{h}$ is a $d$-dimensional context vector from a previous layer, and $\mathbf{W} \in \mathbb{R}^{M \times d}$ is a word embedding, letting $M$ be the vocabulary size. In language modeling, the model might generate a sequence of contexts $\mathbf{h}_1, \ldots, \mathbf{h}_N$ 1. Taking the probability vector for each word and stacking it gives a probability matrix $\mathbf{A}$, generated by the model 2.

Contribution: The authors show that the probability matrix is inherently limited in rank by the dimension of the embedding. This is bad, because there is a set of probability distributions that we can't even represent using our model! So no matter how good we are at the architecture beforehand (using all sorts of fancy dropouts / weight drops / etc.), we are limited by this "softmax bottleneck". Note that $d \approx 300$, and the maximum rank of this matrix is $M$, which is usually 2 orders of magnitude larger. The (simple) solution? Just blending together the softmaxes from $k$ different embeddings. They train this new model, reducing $d$ to make the number of parameters comparable, and get a ~4 PPL reduction from the previous SOTA on PTB + WikiText.
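A toy illustration of the rank limit, in Python with made-up numbers (the contexts, embeddings, and the fixed 50/50 mixture weight are all my own assumptions; the paper uses context-dependent mixture weights). With $d = 1$ the logit matrix $\mathbf{h}\mathbf{w}^\top$ has rank at most 1, and subtracting each row's log-normalizer adds at most one more, so the $3 \times 3$ log-probability matrix has rank at most 2 and its determinant vanishes. Mixing two softmaxes breaks that structure.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def det3(a):
    """Determinant of a 3x3 matrix given as a list of rows."""
    return (a[0][0] * (a[1][1] * a[2][2] - a[1][2] * a[2][1])
            - a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
            + a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0]))


CONTEXTS = [-1.0, 0.5, 2.0]  # three scalar contexts, i.e. d = 1
W_FIRST = [1.0, -0.5, 0.3]   # embedding of each of M = 3 words
W_SECOND = [-0.7, 0.2, 1.1]  # a second embedding, used only by the mixture


def single_softmax_logprobs():
    """Rows are contexts, columns are words: log softmax(h * w)."""
    return [log_softmax([h * w for w in W_FIRST]) for h in CONTEXTS]


def mixture_logprobs(weight=0.5):
    """Log of a fixed-weight mixture of two softmaxes."""
    rows = []
    for h in CONTEXTS:
        p1 = [math.exp(v) for v in log_softmax([h * w for w in W_FIRST])]
        p2 = [math.exp(v) for v in log_softmax([h * w for w in W_SECOND])]
        rows.append([math.log(weight * a + (1.0 - weight) * b)
                     for a, b in zip(p1, p2)])
    return rows
```

Here `det3(single_softmax_logprobs())` is zero up to floating-point error, while `det3(mixture_logprobs())` is not, so the mixture reaches a full-rank log-probability matrix that the single softmax cannot represent.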

Why I like this: It's so simple! And yet, somewhat surprising at the same time. It's the kind of thing I worried about briefly (in an abstract way, without the framework they've provided) when I first learned about softmaxes, but then kind of brushed away. And yet they've shown that just adding a simple correction will make models better. It's transferable to translation / summarization models, and I think whenever I think about a probabilistic model I'll think briefly about what kinds of probability distributions are representable.

Understanding Black-box Predictions via Influence Functions [link]

Background: This paper falls under the umbrella of 'interpretability', which is a heated subject in the research world at the moment.

Contribution: This paper looks at earlier work (Cook & Weisberg, 1982) explaining how the parameters of a learned model might change slightly (i.e., the gradient of the parameters) when we change the weighting of datapoints. It uses this via the chain rule to look at the gradient of the test loss with respect to upweighting/downweighting a data point. Then, it finds (somewhat) fast algorithms for computing this gradient, and applies them to some (small) neural networks. Essentially, it finds a method for 'walking in the space of images' to find a training image that can make your model screw up. Think: insert a random datapoint into the online MNIST files, and suddenly your model is horrible. Note that Yann LeCun doesn't use https on that website.

Why I like this: I'm a big fan of this paper because it uses some fundamental research on influence functions in a new way. Also, the framework it proposes, that of adversarial training images, is definitely interesting. There are some caveats, of course, with work like this: (non)convexity, complicated models (for which the Hessian is not easy!), efficiency, and the authors treat those issues head-on, rather than brushing them aside. That being said, it's unclear whether it'll work without some more algorithmic refinement (a problem I tackled a bit this semester but didn't make much progress on). I also have an implementation of this work in PyTorch that I hope to open source soon, since it's a little nontrivial and probably of use. This is now possible since PyTorch has some support for automatic HVP calculations.

Unbiased Markov chain Monte Carlo with couplings [link]

I'm a bit biased because Pierre taught an excellent inference class this spring which I took, but it was definitely interesting to me.

Background: Because of burn-in, MCMC samples are always biased. The bias can be reduced by extending burn-in, of course, so this might not be a practical concern, but actually burn-in can be a significant contributor (imagine you have to take a few samples from a large set of distributions, like $p(x_i | z_i)$ after you've already sampled $z_i$ in an ad-hoc way).

Contribution: This paper gives a method (via coupled chains) to eliminate the bias in MCMC samples.

It's very rare to see a work both push SOTA on a benchmark, and be faster at the same time.

Gumbel Softmax / Concrete Distribution [link] & [link]

I really mean this section to be about these papers and all of the followups, which I think are somewhat natural extensions of the same line of work. See REBAR, which makes this estimator unbiased by using it as a control variate for the REINFORCE estimator, and RELAX, which takes things one step further by extending it to the RL paradigm, where you can only sample one action rather than a weighted mix of actions.

Background: Variational inference is a powerful technique (see Jeffrey's blog post here), and the research community has made a lot of progress in SVI recently. However, for discrete latent variables, the best we can do is use black-box variational inference, which is interesting, but very hard to implement and tune 3.

Contribution: These papers give a solution by introducing a relaxation of discrete distributions to distributions over the simplex, and also give it a reparametrization gradient. Actually, adding Gumbel noise is a common trick in the economics and ranking literature (see: McFadden 1974, Mattson et al. 2011), but the real innovation here is realizing that taking a softmax over logits + Gumbel noise is reparametrizable, and useful as an approximation to the real argmax (which is the Gumbel max trick). The papers themselves have good explanations, and Eric Jang has a nice blog post about how to do it in practice. Actually I think this connection is worth exploring, so I'll go over it in a blog post later.
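Here's a minimal standard-library sketch of both pieces, the Gumbel max trick and its softmax relaxation. The function names and the clamp on the uniform draw are my own choices, not taken from either paper. `gumbel_max_sample` returns a hard category index, while `gumbel_softmax_sample` returns a point on the simplex that approaches a one-hot vector as the temperature goes to 0.

```python
import math
import random


def gumbel(rng):
    """Standard Gumbel noise via inverse transform sampling."""
    u = max(rng.random(), 1e-12)  # keep u strictly inside (0, 1)
    return -math.log(-math.log(u))


def softmax(xs, temperature=1.0):
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(logits, rng):
    """Hard sample: argmax of logits + Gumbel noise is a draw from softmax(logits)."""
    noisy = [x + gumbel(rng) for x in logits]
    return max(range(len(noisy)), key=noisy.__getitem__)


def gumbel_softmax_sample(logits, temperature, rng):
    """Relaxed sample: same noise, but a softmax in place of the argmax."""
    noisy = [x + gumbel(rng) for x in logits]
    return softmax(noisy, temperature)


def empirical_frequencies(logits, n_samples, seed=0):
    """Fraction of Gumbel-max draws landing on each category."""
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(n_samples):
        counts[gumbel_max_sample(logits, rng)] += 1
    return [c / n_samples for c in counts]
```

Comparing `empirical_frequencies(logits, n)` against `softmax(logits)` for a moderately large `n` is a quick way to convince yourself that the hard version samples from the right categorical distribution; the relaxed version is what makes the gradient with respect to the logits well defined.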

Why I like this: It's really useful! Especially with the two follow up works, this is actually a very useful alternative to BBVI (just see the results from our paper), and makes it much easier to build and train generative models with discrete latent states. For about a month after reading this paper whenever I heard about a generative process I would think: "can this be modeled using a VAE?". Also, I think the papers are well written.

On the Quantitative Analysis of Decoder-Based Generative Models [link]

I'm not sure if this paper counts as "this year", but it was in ICLR 2017 so I'll count it.

Background: It's difficult to evaluate some generative models. Suppose that there's a latent vector $\mathbf{z}$, and we generate $\mathbf{x}$ stepwise, so that $\mathbf{z} \sim p(\mathbf{z})$, and then $\mathbf{x} \sim p(\mathbf{x} | \mathbf{z})$. Then for a hold-out element $\mathbf{x}_{\operatorname{test}}$, we have to integrate out all values of $\mathbf{z}$! This can be expensive, since integration usually means Monte-Carlo sampling. We have some alternatives, like the ELBO or IWAE bounds for VAEs, but even there it is a lower bound, and we're not confident that the ELBO is even asymptotically consistent (the IWAE metric is, but we don't know how fast). For GANs we only seem to have a kernel density estimator (KDE).
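To make the "lower bound" point concrete, here is a toy Python sketch on a model where the marginal is known in closed form. The model ($\mathbf{z} \sim N(0, 1)$, $\mathbf{x} | \mathbf{z} \sim N(\mathbf{z}, 1)$, so $\mathbf{x} \sim N(0, 2)$), the choice of the prior as the proposal, and the function names are all my own assumptions, not anything from the paper. With one sample the estimator is the ELBO for that proposal; with $K$ samples it is the IWAE bound.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


def iwae_bound(x, k, n_outer=2000, seed=0):
    """Monte Carlo estimate of E[log (1/k) sum_i p(x | z_i)] with z_i ~ p(z).

    Because the proposal is the prior, the importance weight p(x, z) / q(z)
    reduces to the likelihood p(x | z).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_outer):
        log_w = [log_normal_pdf(x, rng.gauss(0.0, 1.0), 1.0) for _ in range(k)]
        m = max(log_w)  # log-sum-exp for numerical stability
        total += m + math.log(sum(math.exp(lw - m) for lw in log_w) / k)
    return total / n_outer


def true_log_marginal(x):
    """log p(x) for this toy model, where x ~ N(0, 2)."""
    return log_normal_pdf(x, 0.0, 2.0)
```

For a fixed `x`, `iwae_bound(x, 1)` sits well below `true_log_marginal(x)`, and `iwae_bound(x, 50)` closes most of the gap, which is the sense in which the IWAE bound tightens with more samples while still being a lower bound in expectation.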

Contribution: First, the authors use bidirectional Monte Carlo (BDMC) to upper bound the log marginal, and show that annealed importance sampling is accurate enough to use to evaluate the quality of other estimates. Then, they use it to evaluate kernel density estimators and IWAE on a number of datasets, and also compare VAEs and GANs (!). They essentially show that IWAE is not quite good enough yet, and KDE is very bad in high dimensional spaces, as expected (it's not even consistent). Also, they show that VAEs are significantly better on evaluation than GANs (since they actually fit to a log-marginal estimator).

Why I like this: It ties together the promises of a lot of other papers, and has good experimental results, which are cleanly interpreted. I don't think it's foundational, but it gives good motivation for trying to come up with good evaluation metrics for GANs, and lets people know that their concerns about "lower bounds not being good enough" are warranted.

Frequentist Consistency of Variational Bayes

Background: Variational Bayes / variational inference is well established as a fast alternative to Markov-chain Monte Carlo (MCMC), but we have few theoretical results about it. In particular, even discounting the fact that the variational posterior might be misspecified, we don't know whether, in the large-data limit, the variational posterior is centered at the 'best possible estimate' of the latent variables in some sense, and is asymptotically normal.

Contribution: Consider just the case of estimating the posterior distributions over $\theta$, the model's parameters (so, shared global latent variables). We're being pretty Bayesian here in this sense, so we have a posterior $q(\theta)$. Consider that given $\theta$, we can pick the best possible $q(z)$ (in maximizing the ELBO), and pick the best possible $\theta$ in this way, and call it the variational frequentist estimate (VFE). Then, assuming that this VFE is consistent, the paper shows (1) the VB posterior $q^*(\theta)$ converges to the member of the family that minimizes the KL divergence with a normal centered at the VFE / true parameter, with some variance, and (2) it is consistent. And then they do applications!

Why I liked it: While it doesn't attempt to tackle all of variational inference in one sitting, it is well written and the style of proof reflects how people prove the Bernstein-von Mises theorem, so I assume it has the same level of impact (I have a bad sense). Note that there are liberal technical conditions throughout, but they're usually not egregious (see: Robert Nickl's lecture notes on the BVM here, which gives some insight). Also, I appreciate these kinds of consistency works because they are difficult, but let us be confident that in the long run the work on variational inference will pay off. Mostly, I like proofs :)

footnotes

  1. The paper uses $N$ to denote the set of all possible contexts in a language, making some assumptions about finiteness. I'm simplifying here a bit and talking about the contexts in a single sentence which is smaller. The implications are the same, I think.

  2. Note that this idea of a probability matrix assumes that we're using teacher-forcing, which is usual practice in language modeling. In other words, the distribution over word 2 only depends on the context vector, which only depends on the context vector from the previous state and the correct word 1, not on our choice of word 1. This isn't a hack - it's the right thing to do when trying to estimate the true log marginal.

  3. Does anyone know of an open-source implementation of BBVI (i.e. control variates and Rao-Blackwellization)? Jeffrey and I worked on this, and we think it's very very expensive to compute the 'correct' control variate as recommended in their paper, based on how PyTorch and Tensorflow are implemented, but we might be missing a detail.

https://rachitsingh.com/papers-2017/
PyTorch Internals, cuRAND, and numerical instability
Random sampling

I've been working lately to implement random samplers from a number of distributions in PyTorch, both on CPU and CUDA. This is a topic near and dear to my heart, since it has caused me a lot of trouble multiple times. Once this PR is merged, I'll post an explanation/notebook of why this is important.

Here's a brief summary of the motivation:

  1. We want to sample from distributions like $\operatorname{Beta}(a, b)$. However, it's tricky, because up until recently PyTorch could only sample from a few basic distributions (Uniform, Normal, Exponential, etc.). This is a problem because most fast sampling algorithms for more complex distributions work via rejection sampling (or variants, like ARS), or via inverse transform sampling. The first is tricky because if you want to do it in parallel in pure PyTorch, you need to implement a tricky masking method, and the second is tricky because the inverse CDF is often hard to compute.
  2. Failing that, we can fork out to Numpy. After all, PyTorch seamlessly integrates with Numpy, which has long had excellent support for distributions (more on this later). However, sampling in Numpy involves an expensive CPU-GPU copy, which was actually significant in our models. In our work, the baseline used a Beta distribution, so it would be unfair to compare with this large performance hit.
  3. Finally, failing that, we can write C/CUDA code to sample, and link against PyTorch. That's exactly what we did. The downside of this is that CUDA random number generation is a little tricky, and NVIDIA's cuRAND library only implements a few random number generators. Also, since I am only a makefile novice, it took me forever to get it to compile on Odyssey, and promptly didn't work when I tried to use it on a different environment.

So, my goal lately is to port some of the knowledge gained to PyTorch proper. That way, other researchers can get random $\operatorname{Beta}(a, b)$ samples, fast, without having to jump through all the hoops.

PyTorch internals

PyTorch as a project is pretty complex, but can be surprisingly easy to contribute to if you know where to look. Unfortunately the documentation on internals is sparse 1, and there are two things that make it difficult: there's a mixture of C/C++/CUDA/Python code throughout, and it's glued together with a lot of codegen.

Why is this necessary? PyTorch is a Python library that communicates with C/C++ code (for fast CPU operations), and CUDA (for fast GPU operations). Since there are many data types supported, a lot of the code would be tedious: all of

THFloatTensor * add(THFloatTensor *a, THFloatTensor *b);
THDoubleTensor * add(THDoubleTensor *a, THDoubleTensor *b);
THCFloatTensor * add(THCFloatTensor *a, THCFloatTensor *b);
...

probably have the same implementation. Imagine repeating that 15 times! So not only are the FFI interfaces generated, but the function signatures and implementations too.
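The idea behind the codegen can be sketched in a few lines of Python. This is only an illustration of template-based generation: the `TEMPLATE` string, the type list, and `generate_declarations` are hypothetical and are not how PyTorch's cwrap or ATen generators are actually written.

```python
# Hypothetical template: one C declaration, parameterized by the tensor type.
TEMPLATE = "{type} * add({type} *a, {type} *b);"

# The three type names that appear in the declarations above.
TENSOR_TYPES = ["THFloatTensor", "THDoubleTensor", "THCFloatTensor"]


def generate_declarations(template, tensor_types):
    """Stamp out one declaration per tensor type from a single template."""
    return [template.format(type=t) for t in tensor_types]
```

Printing `"\n".join(generate_declarations(TEMPLATE, TENSOR_TYPES))` reproduces the three declarations listed above from a single source of truth, which is the whole point: adding a new type means adding one entry to a list rather than another hand-written copy.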

Very recently, ATen has made the story somewhat simpler by leveraging C++11 and namespacing to eliminate macros 2.

Here's a few notes I found useful while trying to understand how the build works:

  1. There are 2 different codegen systems: cwrap for generating Python interfaces for some underlying code, and .yaml for an interface from Variable to ATen. So, the torch/csrc/generic/**/*.cwrap files generate Python interfaces and versions of the THTensor_(...) methods for each type, which are dispatched based on the type used. You can jump into that via generate_code.py here.

    For the .yaml files, ATen builds its own interface via this file and outputs Declarations.yaml. Then, generate_code.py reads Declarations.yaml and writes the corresponding Python interface, using gen_variable_type and the derivatives.yaml file. The latter also has information about what the gradient of an operation is.

  2. While building, all the information in CONTRIBUTING.md is very helpful in keeping iteration time down. Also helpful: rewrite build_deps inside setup.py to just build your component (e.g. ATen). Sometimes it gets screwed up and running python setup.py clean is the remedy.

  3. The ATen codegen (starting with gen.py, but mostly in function_wrapper.py) generates the glue that dispatches the correct function based on types. After building, you can find these files in torch/lib/build/aten/src/ATen/ATen/. If you want to mess with the generation, you can modify function_wrapper.py: just find the spot where the corresponding code is generated, and modify options to do what you need. Note that to change just one code path, you'll need to modify many of the codegen points, so look for all of them (Functions.h, CPU[Type]Type.h, etc.).

Mostly I figured this out by running the build, using ag -G [something] [term], and find . -name "[glob]". If you're poking around, these will likely be useful as well. NOTE: by default, ag or rg will ignore the files in your .gitignore. This includes generated build files!

A story about RNG

Recently I was implementing a Poisson sampler using essentially rejection sampling, and found that it didn't work. Here's the code:

__device__ int64_t sample_poisson(double lambda, curandStateMtgp32 *states) {
  if (lambda < 10) {
    double enlam = std::exp(-lambda);
    int64_t X = 0;
    double prod = 1.0;
    double U = 0;
    while (1) {
      U = curand_uniform_double(&states[blockIdx.x]);
      prod *= U;
      if (prod > enlam) {
        X += 1;
      }
      else {
        return X;
      }
    }
  }
  ... // more special case code for values of lambda
}

In particular, if a thread didn't exit in the first or second samples, it would never exit the while loop. I spent a while debugging, and realized that even though calls to curand_uniform_double were uniformly distributed in isolation, adding rejection sampling would cause it to repeat values. The calls are curand_uniform_double(state) for some RNG state state, but state was fine since it generated uniform doubles in isolation. PyTorch uses a MTGP32-based sampler, so I eventually looked in the docs and found this line:

"At a given point in the code, all threads in the block, or none of them, must call this function."

So, what was happening is that threads that returned early stopped calling the function while other threads in the same block kept calling it, which is undefined behavior. This means rejection sampling is hard! However, there's a solution. There's an alternative call, curand_mtgp32_single_specific, which takes a generator state, an index, and a count of the total number of threads that call it. As long as each index is unique and the indices together account for the full thread count, this will give uniformly distributed floats as expected. However, we do need to be a bit careful about how to synchronize because of warp divergence.

__device__ int64_t sample_poisson(double lambda, curandStateMtgp32 *states, int num_threads) {
  __shared__ int thread_count;
  if (threadIdx.x == 0) thread_count = num_threads;
  __syncthreads(); // make the initial count visible to every thread in the block
  int64_t X = 0;
  int idx = threadIdx.x;
  float U = 0;
  float enlam = std::exp(-lambda);
  float prod = 1.0;

  while (thread_count != 0) {
    U = curand_mtgp32_single_specific(&states[blockIdx.x], idx, thread_count);
    prod *= U;
    if (prod > enlam) {
      X += 1;
    }
    __syncthreads();
    if (idx == 0) {
      thread_count = 0;
    }
    __syncthreads();
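    if (prod > enlam) {
      idx = atomicAdd(&thread_count, 1); // counts 'living' threads
    }
    __syncthreads(); // every thread must see the final count before re-checking the loop condition
  }
  return X;
}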

While this is neat, it unfortunately isn't quite appropriate for PyTorch for a few reasons, so we'll look into other solutions. For the Poisson, at least, there's a curand_poisson which implements it natively for us.
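When debugging a GPU sampler like the ones above, it can help to have a plain CPU reference to compare against. Here is a minimal sketch in Python of the same small-lambda branch (multiply uniforms until the product drops to exp(-lambda), often called Knuth's method). The function name sample_poisson_knuth and the sanity check are my own illustration and not part of PyTorch.

```python
import math
import random


def sample_poisson_knuth(lam, rng):
    """Draw one Poisson(lam) sample by multiplying uniforms (small lam only)."""
    threshold = math.exp(-lam)
    count = 0
    prod = 1.0
    while True:
        prod *= rng.random()
        if prod > threshold:
            count += 1      # product still above exp(-lam): keep going
        else:
            return count    # product fell to or below exp(-lam): stop


# Sanity check: a Poisson(lam) variable has mean and variance both equal to lam.
rng = random.Random(0)
lam = 3.0
samples = [sample_poisson_knuth(lam, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)
```

Comparing the empirical mean and variance (or a full histogram) of the GPU output against a reference like this is a cheap way to catch the kind of repeated-value bug described above.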

Some thoughts

One problem that bothered me for more than a week on the IBP project was that our implementation of Beta BBVI went haywire when I used my CUDA sampler. So, following Finale's advice, I made some qq-plots, but couldn't see any real issues. The reason: I was sampling using the identity

$$z \sim \operatorname{Beta}(a, b) \implies z \sim \frac{\operatorname{Gamma}(a)}{\operatorname{Gamma}(a) + \operatorname{Gamma}(b)}$$

since, you know, that's what I learned in Stat 210. But! This is numerically unstable when both $a, b \leq 1$. The solution was found while digging through NumPy's code here, which taught me to respect my elders, or at least to respect NumPy.
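For reference, one standard approach for the $a, b \leq 1$ regime is Jöhnk's algorithm, which I believe is what NumPy's sampler falls back to in that case. The sketch below is my own paraphrase in Python, not NumPy's actual code; the name sample_beta_johnk and the log-space fallback details are assumptions for illustration.

```python
import math
import random


def sample_beta_johnk(a, b, rng):
    """Rejection sampler for Beta(a, b), intended for a <= 1 and b <= 1."""
    while True:
        u = rng.random()
        v = rng.random()
        if u == 0.0 or v == 0.0:
            continue  # avoid log(0) in the fallback below
        x = u ** (1.0 / a)
        y = v ** (1.0 / b)
        total = x + y
        if total > 1.0:
            continue  # reject and draw a fresh pair
        if total > 0.0:
            return x / total
        # Both powers underflowed to zero: form the same ratio in log space.
        log_x = math.log(u) / a
        log_y = math.log(v) / b
        log_max = max(log_x, log_y)
        log_x -= log_max
        log_y -= log_max
        return math.exp(log_x - math.log(math.exp(log_x) + math.exp(log_y)))


rng = random.Random(0)
moderate = [sample_beta_johnk(0.5, 0.5, rng) for _ in range(20000)]
tiny = [sample_beta_johnk(1e-3, 1e-3, rng) for _ in range(2000)]
print(sum(moderate) / len(moderate))
```

The point of the log-space branch is that for very small $a$ or $b$ the powers underflow to zero, and a naive ratio would then be 0/0.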

I wonder whether there's any work still going on for fast random number sampling. It's not something I'm directly interested in, but something I'm curious about.

Another fun story: when later trying to calculate log of the Beta function, I was on my guard and checked out the Cephes implementation, which is roughly 30 years old. At the top it says:

"Direct inquiries to 30 Frost Street, Cambridge, MA 02140"

which is about 2 blocks from where I live.

  1. There are some other blog posts by the PyTorch folks here, definitely also worth checking out.

  2. Which are the devil. My operating systems course, as excellent as it was, was entirely in C and implemented arrays via macros.

https://rachitsingh.com/pytorch/
ELBO Surgery

tldr: The ubiquitous isotropic Gaussian prior for generative models doesn't make sense / doesn't work, which motivates work on priors.

At NIPS, Dawen Liang mentioned Hoffman & Johnson's ELBO surgery paper offhand while talking about tuning KL divergences, and it's very interesting, so I thought I'd go over it. It's very clearly written, so I won't go into any of the derivations, but instead offer my interpretation.

Motivation

I worked in the past on applying variational inference and comparing it to models trained via MAP/MLE inference. I decomposed the evidence lower bound (ELBO) as:

$$\mathcal{L} = \frac{1}{N}\sum_{n = 1}^N\left(\underbrace{\mathbb{E}_{q(z_n | x_n)}[\log p(x_n | z_n)]}_{\text{log-likelihood}} - \underbrace{\operatorname{KL}(q(z_n | x_n) || p(z_n))}_{\text{KL divergence}}\right)$$

This is, I think, the most common interpretation: split the ELBO into a reconstruction term and a KL divergence term. The first encourages the model to reconstruct the data, and the second regularizes the model, asking the posterior distribution over $z_n$ to have a certain shape, like a Gaussian. For example, in a VAE the second term is what prevents the model from just learning a Dirac delta-like posterior $q(z_n | x_n) \sim \mathcal{N}(x_n, 0.001)$ around the original value [1].

In the NLP world, people have seen some problems though: when we have a very powerful generative model (e.g., an RNN), the KL divergence can vanish. This means the posterior $q(z_n | x_n) \approx p(z_n)$ learns nothing about the data, so the generative model $p(x_n | z_n)$ becomes like a language model. The usual trick is to anneal the KL divergence term in, so that inference can be useful. A lot of people are unhappy with this because it adds extra hyperparameters and it feels really non-Bayesian.
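As a rough illustration of what annealing the KL term in looks like, here is a minimal sketch in Python with a linear warm-up. The names kl_weight, annealed_elbo and warmup_steps are hypothetical, and the linear schedule is just one common choice; the warm-up length is exactly the kind of extra hyperparameter people complain about.

```python
def kl_weight(step, warmup_steps):
    """Linear warm-up: the weight goes from 0 to 1 over warmup_steps, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, max(0.0, step / warmup_steps))


def annealed_elbo(log_likelihood, kl_divergence, step, warmup_steps):
    """ELBO with the KL term scaled by the current annealing weight."""
    return log_likelihood - kl_weight(step, warmup_steps) * kl_divergence


# Early in training the KL term is mostly ignored; later it is fully applied.
for step in (0, 500, 1000, 2000):
    print(step, annealed_elbo(-100.0, 10.0, step, warmup_steps=1000))
```

Once the weight reaches 1, the objective is the ordinary ELBO again, so the schedule only changes the path the optimizer takes and not the final objective.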

Contribution

The contribution of this paper is the following observation: the KL divergence above measures the distance from the posterior for a single $z_n$ to the prior, but we really care about the KL divergence from the average posterior over all data points to the prior. So they define

$$q(z) = \frac{1}{N}\sum_{n = 1}^Nq(z_n | x_n)$$

which is the average posterior we see. The intuition here is that when we're trying to do inference, we shouldn't exactly be penalized for being very confident in $q(z_n | x_n)$. However, we want the average distribution to be close to the prior, so this term can go to 0 safely without worrying about whether the posterior has learned something. In fact, at the cost of a lot of extra computation, we can even safely set the prior to be this distribution, or let $p(z) \triangleq q(z)$!

Then, they view $n$, the index variable, as a random variable, where the interpretation is that our generative model samples $n \sim \operatorname{Unif}[N]$, and then picks a $z_n$ from $p(z)$. This isn't totally intuitive, but it makes more sense on the inference side, which we'll see below. Finally, they decompose the second term further as follows:

$$\mathcal{L} = \underbrace{\frac{1}{N}\sum_{n = 1}^N \mathbb{E}_{q(z_n | x_n)}[\log p(x_n | z_n)]}_{\text{log-likelihood}} - \underbrace{\vphantom{\sum_{n = 1}^N}\mathbb{E}_{q(z)}[\operatorname{KL}(q(n | z) || p(n))]}_{\text{index-code mutual information}} - \underbrace{\vphantom{\sum_{n = 1}^N}\operatorname{KL}(q(z) || p(z))}_{\text{marginal KL}}$$

(This is not an obvious derivation, but the math checks out.) Here, $q(n | z)$ (which we can decompose using Bayes' law) can be interpreted as 'the distribution over which datapoint this $z$ belongs to'. The description 'index-code mutual information' comes from an alternative way to write the expression, but I like this one more. Also, they upper bound the index-code mutual information by $\log N$, a not insignificant quantity! This is 11 nats on MNIST.
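To convince yourself that the split of the KL term is exact, it's easy to check numerically on a toy problem. The sketch below is my own illustration (not from the paper): it uses a discrete latent with made-up tables q_z_given_n and prior, assumes a uniform $p(n) = 1/N$, and compares the average KL against the sum of the two new terms.

```python
import math


def kl(p, q):
    """KL(p || q) in nats for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical toy setup: N = 3 data points, 4 discrete latent states.
q_z_given_n = [
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.80, 0.10, 0.05],
    [0.25, 0.25, 0.25, 0.25],
]
prior = [0.4, 0.3, 0.2, 0.1]
N = len(q_z_given_n)
K = len(prior)

# Left-hand side: the usual per-datapoint KL, averaged over the data.
average_kl = sum(kl(row, prior) for row in q_z_given_n) / N

# Aggregated posterior q(z) and the marginal KL term.
q_z = [sum(row[k] for row in q_z_given_n) / N for k in range(K)]
marginal_kl = kl(q_z, prior)

# Index-code mutual information: E_{q(z)}[KL(q(n | z) || p(n))], with p(n) = 1/N.
p_n = [1.0 / N] * N
mutual_info = 0.0
for k in range(K):
    q_n_given_z = [q_z_given_n[n][k] / (N * q_z[k]) for n in range(N)]  # Bayes' law
    mutual_info += q_z[k] * kl(q_n_given_z, p_n)

print(average_kl, mutual_info + marginal_kl, math.log(N))
```

The first two printed numbers should agree up to floating point error, and the mutual information term should never exceed the third, $\log N$.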

Experiments

Finally, the most interesting section, which is the quantitative analysis: they apply the decomposition to a set of the usual VAEs with an isotropic Gaussian prior trained on binarized MNIST, and get the following results:

             ELBO      Average KL   Mutual info.   Marginal KL
2D latents   -129.63   7.41         7.20           0.21
10D latents  -88.95    19.17        10.82          8.35
20D latents  -87.45    20.2         10.67          9.53

So, what's going on is that as we increase the number of latent dimensions to 10/20, the marginal KL gets large! Which means that the Gaussian prior is not good enough anymore. At least, that's my interpretation. This is interesting food for thought, since that gives a lot of evidence for a hunch that people have had for a while (and motivates work on new prior distributions, like our paper).

  1. Well, if the latent capacity is large enough. Otherwise it might learn, e.g. a PCA-like compression, or some other compression if the inference and generation nets are more crazy.

https://rachitsingh.com/elbo-surgery/
NIPS 2017

I'm starting this blog to share research ideas that I have, and some solutions to problems I find along the way. I've been helped immensely by other people's blogs in the past, and want to do the same. Also it'll give me a chance to communicate the way I approach problems, and hopefully people will give me alternative perspectives either by email (rachitsingh@outlook.com) or in the comments, once I figure out how that works.

I'm going to steadily post my thoughts on NIPS 2017 (summary: lots of exposure to new ideas, very glad I went); but as a start I'm going to explain how we went: Jeffrey Ling and I submitted a short paper on Indian Buffet Processes to the NIPS workshop on Advances in Approximate Bayesian Inference (AABI), and we were accepted. See the repo for runnable code and some more information. I'll blog about the IBP as well soon, possibly after we have our arXiv preprint up (soon!).

https://rachitsingh.com/nips/