Main
Computational models are used to devise hypotheses about neural systems and to design experiments to investigate them. When building such models, a central question is how much detail they should include: Models of neural systems range from simple rate-based point neuron models to morphologically detailed biophysical neuron models. The latter provide fine-grained mechanistic explanations of cellular processes underlying neural activity, typically described as systems of ordinary differential equations1,2,3.
However, it has been highly challenging for neuroscientists to create biophysical models that can explain physiological measurements3,4,5 or that can perform computational tasks6,7. It is hardly ever possible to directly measure all relevant properties of the system with sufficient precision to constrain every parameter, necessitating the use of inference or fitting approaches to optimize free model parameters. Yet even finding the right parameters for a single-neuron model with only a few free parameters can be difficult5,8, and large-scale morphologically detailed biophysical network models may have thousands of free parameters governing the behavior of ion channels (for example, maximal conductance), synapses (for example, synaptic conductance or time constant) or neural morphologies (for example, radius or branch length).
Recently, in many domains of science such as particle physics, geoscience and quantum chemistry, differentiable, GPU-accelerated simulators have enabled parameter inference for even complicated models using automatic differentiation techniques9,10,11. Such differentiable simulators make it possible to train simulators with gradient descent methods from deep learning: Backpropagation of error (‘backprop’) makes the computational cost of computing the gradient of the model with respect to the parameters independent of the number of parameters, making it possible to efficiently fit large models. In addition, GPU acceleration allows computing the gradient for many inputs (or model configurations) in parallel, which allows fitting simulations to large datasets.
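As a toy illustration of this point (the dynamics below are a generic placeholder, not a biophysical model), a single reverse-mode call in JAX returns the gradient with respect to all parameters at once, whereas a finite-difference estimate would require one additional simulation per parameter:

```python
import jax
import jax.numpy as jnp

def simulate(params, stimulus):
    """Placeholder dynamics: a leaky integrator whose leak depends on many parameters."""
    def step(v, i_ext):
        v_next = v + 0.025 * (-jnp.sum(params) * v + i_ext)
        return v_next, v_next
    _, trace = jax.lax.scan(step, 0.0, stimulus)
    return trace

def loss(params, stimulus):
    return jnp.mean(simulate(params, stimulus) ** 2)

stimulus = jnp.ones(1000)
params = 1e-4 * jnp.ones(10_000)

# One reverse-mode pass yields d(loss)/d(params) for all 10,000 parameters at once.
grads = jax.grad(loss)(params, stimulus)

# A finite-difference estimate would instead need one extra simulation per parameter,
# that is, roughly 10,000 additional forward passes for the same gradient.
```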
Numerical solvers for biophysical models in neuroscience are used extensively, and several software packages exist, in particular the commonly used Neuron simulation environment12. Yet, none of these simulators allows performing backprop, and currently used simulation engines are primarily CPU-based, with GPU functionality only added post hoc13,14,15. As a consequence, state-of-the-art methods for parameter estimation in biophysical neuron models are based on gradient-free approaches such as genetic algorithms8,16 or simulation-based inference17, which do not scale to models with many parameters.
Inspired by the capabilities of deep learning to adjust millions (or even billions) of parameters given large datasets, we here propose to optimize biophysical parameters with gradient descent. To this end, we developed Jaxley, a toolbox for biophysical simulation which, unlike previous simulation toolboxes for biophysical models, can compute the gradient with backprop. In addition, Jaxley leverages GPU acceleration to speed up training. We apply Jaxley to a series of tasks, ranging from fitting physiological data (that is, match experimental recordings such as voltage or calcium measurements8,18) to solving computational tasks7,19 (Fig. 1a). We show that gradient descent can be orders of magnitude more efficient than gradient-free methods, and that it enables training biophysical networks with 100,000 parameters. This unlocks possibilities for data-driven and large-scale biophysical simulations in neuroscience.
Fig. 1: Differentiable simulation enables training biophysical neuron models.
a, Schematic of goal: training biophysically detailed neural systems. b, Schematic of method: our simulator, Jaxley, can simulate biophysically detailed neural systems, and it can also perform backprop. c, Jaxley can parallelize simulations on (multiple) GPUs/TPUs, and it can just-in-time (JIT) compile code. d, Reconstruction of a CA1 neuron24 and responses to a step current obtained with the Neuron simulator and with Jaxley. Inset is a zoom-in view of the peak of the action potential. Scale bars, 3 ms and 30 mV. e, Left: time to run 10,000 simulations with Neuron on a CPU and with Jaxley on a GPU. Right: Simulation time (top) for the CA1 neuron shown in d and for a point neuron, as a function of the number of simulations. Bottom: same as top, for computing the gradient with backprop. f, Biophysically detailed network built from reconstructions of CA1 neurons (left) and its neural activity in response to step currents to the first layer (right). Runtimes were evaluated on an A100 GPU. M, million; ML, machine learning; Sim., simulation.
Results
Jaxley is a differentiable simulator for neuroscience
Jaxley is a Python toolbox for simulation and training of biophysical neuron models. Jaxley implements numerical routines required for efficiently simulating biophysically detailed neural systems, so-called implicit Euler solvers, in the deep learning Python framework JAX20. The automatic differentiation capabilities of JAX enable Jaxley to use backprop to efficiently compute the gradient with respect to any biophysical parameter, including ion channel, synaptic or morphological parameters (Fig. 1b).
For computational speed, Jaxley implements differential equations such that networks, parameter sets or input stimuli can be processed in parallel on GPUs, providing speedups for datasets (via stimulus parallelization) or for parameter sweeps (via parameter parallelization)13,15. Jaxley further speeds up simulation and training with just-in-time compilation (Fig. 1c).
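As a minimal sketch of this approach, consider a single passive compartment obeying C dV/dt = −g_leak(V − E_leak) + I. The implicit (backward) Euler update is linear in the new voltage and can be solved in closed form at each step, and JAX can then differentiate, vectorize and compile the resulting simulation. Jaxley's actual solver handles full multicompartment morphologies and active channels, so the code below only illustrates the principle:

```python
import jax
import jax.numpy as jnp

DT, C_M = 0.025, 1.0   # time step (ms) and membrane capacitance (illustrative units)

def simulate(g_leak, e_leak, stimulus):
    """Implicit (backward) Euler for C dV/dt = -g_leak (V - e_leak) + I.
    The update C (V' - V)/dt = -g_leak (V' - e_leak) + I is linear in V',
    so every step can be solved exactly."""
    def step(v, i_ext):
        v_next = (C_M * v + DT * (g_leak * e_leak + i_ext)) / (C_M + DT * g_leak)
        return v_next, v_next
    _, trace = jax.lax.scan(step, e_leak, stimulus)
    return trace

def loss(g_leak, stimulus, target):
    return jnp.mean(jnp.abs(simulate(g_leak, -70.0, stimulus) - target))

stimuli = jnp.zeros((32, 4000)).at[:, 1000:3000].set(0.2)   # batch of step currents
targets = jnp.full((32, 4000), -65.0)

# Gradient w.r.t. the conductance, batched over stimuli and JIT-compiled.
grad_fn = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
grads = grad_fn(0.3, stimuli, targets)
```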
Training biophysical models with gradient descent leads to instabilities resulting from parameters having different scales, networks having a large computation graph21 and loss surfaces being non-convex. Jaxley implements methods that have been developed to overcome these specific issues in deep neural networks (Extended Data Fig. 1). For example, it implements parameter transformations, multilevel checkpointing22 and optimizers for non-convex loss surfaces (for example, Polyak gradient descent23). Furthermore, we designed Jaxley with a user-friendly interface, allowing neuroscientists to build biophysical models (for example, by inserting recordings, stimuli and channels into various branches or cells, or by implementing different connectivity structures such as sparse or dense connectivity) and to use automatic differentiation and GPU parallelization. In a dedicated library open to the community, it also implements a growing set of ion channel and synapse models. Jaxley is fully written in Python, which makes it easy for the community to use and to extend. Jaxley is openly available at https://github.com/jaxleyverse/jaxley/.
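As one example of this machinery, the sketch below shows a sigmoid-based parameter transformation that keeps conductances within plausible bounds while the optimizer works in an unconstrained space, combined with an off-the-shelf optax optimizer. The bounds, the loss and the optimizer choice are illustrative and not necessarily those used by Jaxley:

```python
import jax
import jax.numpy as jnp
import optax

LOWER, UPPER = 1e-4, 0.5   # illustrative conductance bounds (S/cm^2)

def to_bounded(u):
    """Map unconstrained optimizer variables to bounded conductances."""
    return LOWER + (UPPER - LOWER) * jax.nn.sigmoid(u)

def to_unconstrained(g):
    """Inverse transform (logit) for initializing the optimizer variables."""
    p = (g - LOWER) / (UPPER - LOWER)
    return jnp.log(p) - jnp.log1p(-p)

def loss_fn(u):
    g = to_bounded(u)
    return jnp.sum((g - 0.1) ** 2)   # placeholder for a simulation-based loss

u = to_unconstrained(jnp.array([0.05, 0.2, 0.3]))
optimizer = optax.adam(learning_rate=0.05)
opt_state = optimizer.init(u)

for _ in range(100):
    grads = jax.grad(loss_fn)(u)
    updates, opt_state = optimizer.update(grads, opt_state)
    u = optax.apply_updates(u, updates)

print(to_bounded(u))   # conductances stay inside [LOWER, UPPER] throughout training
```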
Jaxley is accurate, fast and scalable
We benchmarked the accuracy, speed and scalability of Jaxley for simulation of biophysical models. First, we evaluated the accuracy of Jaxley by creating biophysically detailed multicompartment models of a CA1 pyramidal cell in the rat hippocampus24,25 and of four layer 5 neurons in the mouse visual area from the Allen Cell Types Database26. Every model contained sodium, potassium and leak channels in all branches. We stimulated the soma and recorded the voltage at three locations across the dendritic tree. Jaxley matched the voltages of the Neuron simulator at sub-millisecond and sub-millivolt resolution (Fig. 1d and Extended Data Fig. 2).
Next, we evaluated the simulation speed of Jaxley on CPUs and GPUs. We simulated the above-described CA1 cell and a single-compartment model for 20 ms. On a GPU, Jaxley was much faster than Neuron for large systems or many parallel simulations, with a speedup of around two orders of magnitude (Fig. 1e and Supplementary Fig. 1). For single-compartment neurons, Jaxley could parallelize the simulation of up to 1 million neurons, thereby allowing fast parameter sweeps. On a CPU, Jaxley was at least as fast as Neuron.
We then evaluated the computational cost of computing the gradient with Jaxley. For backpropagation, the forward pass must be stored in memory, which can easily correspond to terabytes of data for large neural systems. To overcome this, Jaxley implements multilevel checkpointing22, which reduces memory usage by strategically saving and recomputing intermediate states of the system of differential equations. We found that, depending on the simulation device (CPU, GPU or TPU), the number of simulations, the simulated time and the loss function, computing the gradient was between 3 and 20 times more expensive than running the simulation itself (Fig. 1e).
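The sketch below illustrates the basic checkpointing idea along the time axis (a single level, not Jaxley's multilevel scheme): each chunk of the scan is wrapped in jax.checkpoint, so only the voltages at chunk boundaries are stored during the forward pass and chunk interiors are recomputed during the backward pass. The toy single-compartment dynamics are a placeholder:

```python
import jax
import jax.numpy as jnp

DT = 0.025  # ms

def simulate_checkpointed(g_leak, stimulus, chunk_size=200):
    """Backward-pass memory scales with the number of chunks: only the state at
    chunk boundaries is stored, and chunk interiors are rematerialized."""

    @jax.checkpoint
    def run_chunk(v, chunk):
        def body(v, i_ext):
            v_next = v + DT * (-g_leak * (v + 70.0) + i_ext)
            return v_next, v_next
        return jax.lax.scan(body, v, chunk)

    chunks = stimulus.reshape(-1, chunk_size)
    _, traces = jax.lax.scan(run_chunk, -70.0, chunks)
    return traces.reshape(-1)

def loss(g_leak, stimulus):
    return jnp.mean(simulate_checkpointed(g_leak, stimulus) ** 2)

stimulus = 0.1 * jnp.ones(8000)        # 8,000 steps at dt = 0.025 ms, i.e. 200 ms
grad = jax.grad(loss)(0.3, stimulus)   # gradient w.r.t. the leak conductance
```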
Finally, we show that in addition to parallelizing across parameters (or across stimuli), Jaxley can parallelize across branches or compartments in a network. We built a network consisting of 2,000 morphologically detailed neurons with Hodgkin–Huxley dynamics, connected by 1 million biophysical synapses (3.92 million differential equation states in total; Fig. 1f). On a single A100 GPU, Jaxley computed 200 ms (that is, 8,000 steps at Δt = 0.025 ms) of simulated time in 21 s. We then used backprop to compute the gradient with respect to all membrane and synaptic conductances in this network (3.2 million parameters in total), which took 144 s. Estimating the gradient with finite differences—as would be required for packages that do not support backprop—would take more than 2 years (3.2 million forward passes, 21 s each). With simplified morphologies, Jaxley can scale backprop to networks with many millions of synapses on a single GPU (Extended Data Fig. 3).
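For reference, the "more than 2 years" figure follows directly from the runtimes quoted above:

$$
3.2\times10^{6}\ \text{forward passes}\times 21\ \mathrm{s}\;\approx\;6.7\times10^{7}\ \mathrm{s}\;\approx\;780\ \text{days}\;\approx\;2.1\ \text{years},
$$

compared with a single 144-s backward pass with backprop.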
Fitting single-neuron models to intracellular recordings
Having demonstrated the accuracy and speed of Jaxley, we applied it to a series of tasks that illustrate how Jaxley opens up opportunities for building task-driven or data-driven biophysically detailed neuroscience models in a range of scenarios. As a first proof of principle, we used Jaxley to fit single-neuron models with few parameters. We built a biophysical neuron model based on a reconstruction of a layer 5 pyramidal cell (L5PC) (Fig. 2a). The model had nine different channels in the apical and basal dendrites, the soma and the axon16, with a total of 19 free parameters. We learned these parameters from a synthetic somatic voltage recording, generated with a known set of ground-truth parameters in response to a somatic step-current stimulus (Fig. 2b).
Fig. 2: Inferring single-neuron models with gradient descent.
a, (Task 1) Morphology of an L5PC. b, Top: synthetic somatic voltage recording (black) and windows used to compute summary statistics. Bottom: fits obtained with gradient descent. Best fit in dark blue; fits from independent runs in light blue. Scale bars, 20 ms and 30 mV. c, Loss value of individual gradient descent runs (light blue) and their minimum (dark blue), in comparison to the minimum loss across ten genetic algorithm runs (black). d, Neuron morphologies and patch-clamp recordings (black) in response to step currents from the Allen Cell Types Database. Gradient descent fit in blue. Additional models in Extended Data Fig. 5. Scale bars, 200 ms and 30 mV. e, (Task 2) Synthetic conductance profile of an L5PC morphology. f, Simulated voltages given the synthetic conductance profile after 1.5 ms and 2.5 ms. g, Predicted voltages of the gradient descent fit. h, Ground-truth (gt) conductance profile as a function of distance from the soma (black) and 90% confidence interval obtained with multi-chain gradient-based Hamiltonian Monte Carlo. i, Loss for gradient descent and the genetic algorithm. j, (Task 3) Nonlinearly separable input stimulus amplitudes (left) and a simplified morphology with 12 compartments (right). k, Voltage traces of the model found with gradient descent. l, Decision surface of the model reveals nonlinear single-neuron computation. m, Minimum loss across ten independent runs for gradient descent and the genetic algorithm. alg., algorithm; desc., descent.
We used gradient descent to identify parameter sets that minimize the mean absolute error to summary statistics of the voltage trace. Because gradient descent requires differentiable summary statistics, whereas commonly used summary statistics of intracellular recordings—such as spike count—can be discrete or non-differentiable, we used the mean and standard deviation of the voltage in two time windows17. Starting from randomly initialized parameters, gradient descent required only nine steps (median across ten runs) to find models whose voltage traces were visually similar to the observation (Fig. 2b). A state-of-the-art indicator-based genetic algorithm (IBEA)16 required a similar number of iterations, although each of its iterations used ten simulations. As a consequence, gradient descent required almost ten times fewer simulations than the genetic algorithm and, despite the additional cost of backpropagation, found good parameter sets in less runtime than the genetic algorithm on a CPU (Fig. 2c and Extended Data Fig. 4a).
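A sketch of such a differentiable loss is shown below; the time windows and the equal weighting of mean and standard deviation are illustrative choices, not necessarily those used in the experiments:

```python
import jax.numpy as jnp

DT = 0.025  # ms

def window_stats(voltage, t_start, t_end):
    """Mean and standard deviation of the voltage within one time window."""
    window = voltage[int(t_start / DT): int(t_end / DT)]
    return jnp.array([window.mean(), window.std()])

def summary_loss(simulated_v, observed_v, windows=((10.0, 50.0), (50.0, 90.0))):
    """Mean absolute error between summary statistics of simulation and data.
    Means and standard deviations are smooth functions of the voltage, so the
    loss is differentiable, unlike spike counts."""
    loss = 0.0
    for t0, t1 in windows:
        loss += jnp.mean(jnp.abs(window_stats(simulated_v, t0, t1)
                                 - window_stats(observed_v, t0, t1)))
    return loss / len(windows)
```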
We then used gradient descent to identify parameters that match patch-clamp recordings of four cells from the Allen Cell Types Database26. We inserted the same set of ion channels, but, to account for the diversity and complexity of the experimental recordings, we made six additional parameters trainable and used a loss function based on dynamic time warping (DTW). To reduce the fitting time caused by the length of the recordings (1 s versus 100 ms in the synthetic experiments), we initially fitted only the first 200 ms of each trace and then performed an additional fitting step on the entire trace.
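One differentiable formulation of a DTW-based loss is soft-DTW (Cuturi and Blondel, 2017), which replaces the hard minimum in the DTW recursion with a soft minimum. The sketch below is a straightforward, unoptimized implementation of that idea and is not necessarily the exact loss used here:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def soft_dtw(x, y, gamma=1.0):
    """Soft dynamic time warping between two 1D traces. The hard min of
    classical DTW is replaced by a soft min, which makes the alignment
    cost differentiable in both inputs."""
    n, m = x.shape[0], y.shape[0]
    cost = (x[:, None] - y[None, :]) ** 2          # pairwise squared distances

    big = 1e10                                     # stands in for +infinity
    r = jnp.full((n + 1, m + 1), big).at[0, 0].set(0.0)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prev = jnp.array([r[i - 1, j], r[i, j - 1], r[i - 1, j - 1]])
            softmin = -gamma * logsumexp(-prev / gamma)
            r = r.at[i, j].set(cost[i - 1, j - 1] + softmin)
    return r[n, m]

# Example: two short, slightly shifted voltage-like traces.
x = jnp.sin(jnp.linspace(0.0, 3.0, 40))
y = jnp.sin(jnp.linspace(0.2, 3.2, 40))
value, grad_wrt_x = jax.value_and_grad(soft_dtw)(x, y)
```

For traces with thousands of time points, the same recursion would in practice be expressed with jax.lax.scan (for example over anti-diagonals) rather than Python loops.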
We first used gradient descent with a low computational budget (ten runs with ten iterations each; loss in Extended Data Fig. 4b) and found that the resulting traces roughly matched the firing rate of the experimental recordings, but did not yet capture other features such as spike frequency adaptation (Extended Data Fig. 5). To improve the fits, we used the ability of Jaxley to parallelize several fitting runs. We parallelized 1,000 gradient descent runs on a GPU and found parameter sets whose voltage traces closely resembled the experimental recordings (Fig. 2d and Extended Data Fig. 5). Using Jaxley, we also parallelized a genetic algorithm on a GPU and found that the resulting fits were of similar quality (Extended Data Fig. 5). Overall, these results demonstrate the ability of gradient descent to fit biophysical models to intracellular recordings, being competitive with state-of-the-art genetic algorithms even on tasks for which such algorithms have been extensively optimized.
Fitting single-neuron models with many parameters
How does gradient descent scale to models with large numbers of parameters? We demonstrate here that, in contrast to genetic algorithms, gradient descent can optimize a single-neuron model with 1,390 parameters.
We used the above-described model of an L5PC. Unlike in the above experiments, we fit the maximal conductance of ion channels in every branch in the morphology, thereby allowing us to model effects of nonuniform conductance profiles27. This increased the number of free parameters to 1,390. To generate a synthetic recording, we assigned a different maximal conductance to each branch (sampled from a Gaussian process), depending on the distance from the soma (Fig. 2e). We recorded the voltage at every branch of the model in response to a 5-ms step-current input (Fig. 2f). Experimentally, such data could be obtained, for example, through voltage imaging.
We used gradient descent to identify parameters that match this recording, with a regularizer that penalizes the difference between parameter values in neighboring branches27. Despite the large number of parameters, gradient descent found a parameter set whose voltage response closely matched the observed voltage throughout the dendritic tree (Fig. 2g). To understand how much the whole-cell voltage recording constrains the parameters, we used Bayesian inference (implemented with gradient-based Hamiltonian Monte Carlo) to infer an ensemble of parameter sets, all of which match the observed voltage (Extended Data Fig. 6). The resulting ensemble revealed regions along the dendritic tree at which the conductance profile was strongly constrained by the data (for example, the transient sodium channel around 400 μm, Fig. 2h), but it also revealed conductance profiles that were only weakly constrained by the data (Fig. 2h and Extended Data Fig. 6)2. Finally, we compared our method with an indicator-based genetic algorithm8 and, as expected, found that access to gradients leads to better convergence: While gradient descent converged to values of low loss within 100 iterations, the loss of the genetic algorithm remained two orders of magnitude higher even after 500 iterations (Fig. 2i).
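A sketch of such a neighbor-difference regularizer is shown below. The vector of per-branch conductances and the adjacency index arrays are illustrative names, assuming branch connectivity is available as parent–child index pairs (this is not Jaxley's internal representation):

```python
import jax.numpy as jnp

def smoothness_penalty(branch_conductances, parent_idx, child_idx, weight=1.0):
    """Penalize squared differences between the conductances of neighboring
    branches; parent_idx[k] and child_idx[k] index the two branches joined
    by edge k of the morphology."""
    diffs = branch_conductances[parent_idx] - branch_conductances[child_idx]
    return weight * jnp.sum(diffs ** 2)

# Example: a chain of four branches, 0-1-2-3.
g = jnp.array([0.10, 0.12, 0.30, 0.31])
penalty = smoothness_penalty(g, jnp.array([0, 1, 2]), jnp.array([1, 2, 3]))
```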
Nonlinear single-neuron computation
Next, we trained a single-neuron model to solve a nonlinear pattern separation task on its dendritic tree. While it has been demonstrated extensively that single-neuron models respond nonlinearly to inputs28, it has so far been difficult to train biophysically detailed neurons on a particular task. Here, we show that stochastic gradient descent enables training single-neuron models with dendritic nonlinearities to perform nonlinear computations.
We defined a simple morphology consisting of a soma and two dendrites, and inserted sodium, potassium and leak channels into all neurites of the cell. We then learned ion channel densities as well as the length, radius and axial resistivity of every compartment (72 parameters in total) such that the neuron produced a low somatic voltage (−70 mV) when both dendrites were stimulated with step currents of intermediate strength, and a high somatic voltage (35 mV) when one of the dendrites was stimulated strongly and the other weakly (Fig. 2j). The two classes were therefore not linearly separable, requiring the neuron to perform a nonlinear computation.
After training the parameters with gradient descent, we found that the cell indeed learned to perform this task and spiked only when one dendrite was stimulated strongly (Fig. 2k), effectively having a nonlinear decision surface (Fig. 2l). We again compared gradient descent to an indicator-based genetic algorithm and found that gradient descent finds regions of lower loss more quickly than genetic algorithms (Fig. 2m).
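The decision surface in Fig. 2l can be mapped out by sweeping the two dendritic stimulus amplitudes over a grid and recording the peak somatic voltage. The sketch below assumes a hypothetical wrapper simulate_soma(params, amp1, amp2) around the trained Jaxley model and an illustrative amplitude range:

```python
import jax
import jax.numpy as jnp

def decision_surface(params, simulate_soma, n_grid=20):
    """Sweep both dendritic stimulus amplitudes over a grid and record the
    peak somatic voltage. `simulate_soma(params, amp1, amp2)` is a hypothetical
    wrapper that runs the trained model for one pair of dendritic step currents
    and returns the somatic voltage trace."""
    amps = jnp.linspace(0.0, 0.25, n_grid)      # illustrative amplitude range (nA)

    def peak(a1, a2):
        return simulate_soma(params, a1, a2).max()

    # Nested vmap evaluates all n_grid x n_grid stimulus pairs in parallel.
    return jax.vmap(lambda a1: jax.vmap(lambda a2: peak(a1, a2))(amps))(amps)
```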
Overall, these results show that gradient descent performs better than gradient-free methods in models with many parameters, opening up possibilities for studying biophysical mechanisms at scale throughout the full neuronal morphology.
Hybrid retina model of dendritic calcium measurements
So far, we have learned parameters of single-neuron models using small datasets consisting of a few stimulus–response pairs. Many models of neural systems, however, consist of multiple neurons, and datasets can contain thousands of stimulus–response pairs7,29,30. Using a network model of the mouse retina, we demonstrate that Jaxley can simultaneously infer cell-level and network-level parameters such that model simulations match large-scale datasets.
We consider transient Off alpha retinal ganglion cells (RGCs) in the mouse retina, which show compartmentalized calcium signals in their dendrites in response to visual stimulation18. To understand the mechanistic underpinning of this behavior, we built a hybrid model with statistical and mechanistic components: We modeled photoreceptors as a convolution with a Gaussian filter, bipolar cells as point neurons with a nonlinearity and an RGC as a morphologically detailed biophysical neuron, with six different ion channels distributed across its soma and dendritic tree (Fig. 3a). To model the fluorescence signal of the calcium indicator, we convolved the intracellular calcium (from the calcium channel of the model) with a calcium kernel (Fig. 3b).
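A sketch of the statistical components of such a hybrid model is shown below. The kernel sizes, the softplus nonlinearity and the function names are illustrative assumptions, and the RGC itself would be a Jaxley multicompartment model rather than the code shown here:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve

def gaussian_kernel(size=9, sigma=2.0):
    x = jnp.arange(size) - size // 2
    k = jnp.exp(-0.5 * (x / sigma) ** 2)
    k = k / k.sum()
    return jnp.outer(k, k)

def photoreceptor_layer(stimulus_frame, sigma=2.0):
    """Photoreceptors: spatial Gaussian blur of one checkerboard frame."""
    return convolve(stimulus_frame, gaussian_kernel(sigma=sigma), mode='same')

def bipolar_layer(pr_output, weights, bias):
    """Bipolar cells as point neurons: weighted sum of photoreceptor output
    followed by a pointwise nonlinearity (softplus is an illustrative choice)."""
    return jax.nn.softplus(weights @ pr_output.ravel() + bias)

def fluorescence(intracellular_ca, ca_kernel):
    """Convolve the model's intracellular calcium with an indicator kernel
    to predict the measured fluorescence signal."""
    full = convolve(intracellular_ca, ca_kernel, mode='full')
    return full[: intracellular_ca.shape[0]]
```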
Fig. 3: Hybrid model of the calcium responses of an RGC.
a, Schematic of experimental setup and hybrid model. b, Schematic of training procedure and loss function. c, Measured and model-predicted calcium response across 50 noise images (200 ms each). d, Left: calcium response (color map) of the trained hybrid model to a step current to a single branch indicated by step-current sketch. Middle: voltage activity of the model at two branches, one at the stimulus site and one at a distant branch. Right: intracellular calcium concentration in the same two recording sites. Scale bars, 50 ms, 30 mV and 0.025 mM. e, Receptive fields of the hybrid model. f, Pearson correlation coefficient between experimental data and model for train (top) and test (bottom) data, for a linear network, a multilayer perceptron, and the hybrid model. Error bars show the s.e.m. over seven datasets (Methods). Asterisks denote a statistically significant difference between mean correlations of hybrid model and multilayer perceptron (MLP; one-sided t-test at P < 0.05, P = {0.0030, 0.0044, 0.017, 0.066, 0.80}). a.u., arbitrary units; BCs, bipolar cells; intra., intracellular; MAE, mean absolute error; PRs, photoreceptors.
Using Jaxley, we trained the hybrid model to predict dendritic calcium on 15,000 pairs of checkerboard noise stimuli and calcium recordings. We learned synaptic conductances from the bipolar cells onto the RGC (287 synaptic parameters), as well as cellular RGC parameters (320 cell parameters). After training, we evaluated the trained model on a held-out test dataset. The model had a positive Pearson correlation coefficient with the experimental recording on 146 of 147 recording sites, with an average correlation of 0.25, and a maximum of 0.51 (Fig. 3c).
Next, we tested whether the trained model also reproduced the compartmentalized structure of calcium responses, which has been experimentally measured18. We stimulated the model at a distal branch and recorded the model calcium response across all branches of the cell. We found that the calcium signal in response to local stimulation did not propagate through the entire cell (Fig. 3d). In addition, the model receptive fields did not cover the entire cell and were roughly centered around the recording locations (Fig. 3e). This demonstrated a compartmentalized response of the model and qualitatively matched the receptive fields obtained from experimental measurements18.
The mechanistic components of the model, including the anatomical structure, provide an inductive bias. We therefore investigated whether this inductive bias could lead to better generalization to new data, especially when training data are scarce. We trained a linear model, a two-layer perceptron and the hybrid model on reduced datasets of recordings and compared their performance on a held-out test set (Fig. 3f). While the linear model and the perceptron performed better than the hybrid model on training data, the hybrid model performed better on held-out test data when little training data were available. These results indicate that the inductive bias of the hybrid model can effectively limit overfitting, suggesting that hybrid components could be used as regularizers for deep neural network models of neural systems31.
Our results demonstrate that gradient descent enables fitting networks of biophysical neurons to large calcium datasets and allows simultaneous learning of cell-level and network-level parameters.
Biophysical RNNs solve working memory tasks
To understand how computations are implemented in neural circuits, computational neuroscientists aim to train models to perform tasks7,19. In particular, recurrent neural networks (RNNs) have been used to form hypotheses about population dynamics underlying cognition. Typically, such RNNs consist of point neurons with rate-based or simplified spiking dynamics, which prevents studying the contribution of channel dynamics or cellular processes. We here show how Jaxley makes it possible to train biophysical models of neuronal networks to perform such tasks.
We implemented in Jaxley an RNN consisting of Hodgkin–Huxley-type neurons with a simplified apical and basal dendrite, with each neuron equipped with a variety of voltage-gated ion channels4. We sparsely connected the recurrent network with conductance-based synapses and obtained the outputs from passive readout units (Fig. 4a).
Fig. 4: RNN models with Hodgkin–Huxley-type neurons perform working memory tasks.
a, Schematic of the RNN. b, Autonomous dynamics of the recurrently connected neurons before learning the parameters for different values of synaptic gain. c, Maximal Lyapunov exponent indicates transition to chaos with increasing synaptic gain. d, (Task 1) Evidence integration task. Gaussian noise stimulus (top) and voltage traces of the two readout neurons. Response period in gray. e, Histogram of initial and trained input, recurrent and output weights of the network. f, Psychometric curve showing the fraction of times the RNN reported the summed input to be greater than zero (r1 > r2), as a function of the stimulus mean. g, (Task 2) Delayed-match-to-sample task. Two stimuli separated by a delay (for two of four input patterns), raster plots of the trained network activity, and readout neuron prediction.
We first investigated the dynamics in this biophysical RNN before training. As with rate-based RNNs, these dynamics were strongly dependent on a global scaling factor (the 'gain') of all recurrent synaptic maximal conductances32. Our RNN transitioned from a stable to a chaotic regime as the gain was increased, with an intermediate region where networks displayed regular firing (Fig. 4b). The ability of Jaxley to perform automatic differentiation allowed us to quantify the stability of these networks by numerically computing Lyapunov exponents33 (Fig. 4c and Supplementary Fig. 2).
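The sketch below shows one way automatic differentiation can enter this computation: a Benettin-style estimate of the maximal Lyapunov exponent that propagates a tangent vector through one solver step at a time with jax.jvp and renormalizes it after every step. The step function and the state layout of the network are assumptions here:

```python
import jax
import jax.numpy as jnp

def max_lyapunov_exponent(step_fn, state0, n_steps, dt):
    """Benettin-style estimate of the maximal Lyapunov exponent of the map
    `step_fn(state) -> state` (one solver step of the network). A tangent
    vector is propagated with jax.jvp and renormalized after every step."""
    tangent = jax.random.normal(jax.random.PRNGKey(0), state0.shape)
    tangent = tangent / jnp.linalg.norm(tangent)

    def body(carry, _):
        state, tangent, log_sum = carry
        new_state, new_tangent = jax.jvp(step_fn, (state,), (tangent,))
        norm = jnp.linalg.norm(new_tangent)
        return (new_state, new_tangent / norm, log_sum + jnp.log(norm)), None

    (_, _, log_sum), _ = jax.lax.scan(body, (state0, tangent, 0.0), None, length=n_steps)
    return log_sum / (n_steps * dt)
```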
We then trained the biophysical RNN to perform two working memory tasks, starting with a perceptual decision-making task requiring evidence integration over time34. We built a network of 20 recurrent neurons and stimulated each recurrent neuron with a noisy time series with either positive or negative mean value (Fig. 4d). We trained input weights, recurrent weights and readout weights (109 parameters) such that the network learned to differentiate between positive and negative stimuli during a response period after 500 ms. Despite the long time horizon of this task (500 ms, corresponding to 20,000 time steps of the simulation), gradient descent found parameters such that the RNN was able to perform the task (Fig. 4d and Supplementary Fig. 3), with the voltage in the readout neurons differentiating the input means with 99.9% accuracy across 1,000 trials. To solve this evidence integration task, some input, recurrent and output weights were pushed toward zero and the remaining weights were pushed toward their positive and negative constraints during training (Fig. 4e).
We also evaluated the generalization abilities of the trained biophysical RNN. We varied the mean value of the positive and negative stimuli and found that the relationship between average response and stimulus mean closely resembled the well-known sigmoidal psychometric curve of decision-making, where the network more often failed when the stimulus had a lower signal-to-noise ratio (Fig. 4f). In addition, the trained RNN generalized to evidence integration tasks of longer durations (Extended Data Fig. 7).
We next used the RNN to solve a more challenging working memory task, a delayed-match-to-sample task, where the RNN had to maintain information over an extended period of time19,35. We trained the RNN to classify patterns, consisting of two step-current inputs with a delay between them, into matching (same identity of the inputs) or non-matching (different identity of the inputs; Fig. 4g). We used curriculum learning to solve this task. By training input, recurrent and readout weights, as well as synaptic time constants of a network with 50 recurrent neurons (542 parameters), we found parameter sets that solved the task and correctly classified all four patterns (Fig. 4g). We then inspected the population dynamics of the trained biophysical RNN and found that the network used a form of transient coding to solve the task (Extended Data Fig. 8).
Overall, these results demonstrate that gradient descent allows training RNNs with biophysical detail to solve working memory tasks. This will allow a more quantitative investigation of the role of cellular mechanisms contributing to behavioral and cognitive computations.
Training biophysical networks with 100,000 parameters
Finally, we show that gradient descent enables training of large biophysical models with thousands of cellular-level and network-level parameters on machine learning-scale datasets to solve classical computer vision tasks such as image recognition.
We implemented a feedforward biophysical network model in Jaxley and trained it to solve the classical MNIST task, without artificial nonlinearities such as ReLU activations. The network had three layers: The input and output layers consisted of neurons with ball-and-stick morphologies and the hidden layer consisted of 64 morphologically detailed models obtained from reconstructions of CA1 cells (Fig. 5a)15,24. The network was interconnected by biophysical synapses. We trained sodium, potassium and leak conductances of every branch in the circuit (55,000 parameters), as well as all synaptic weights (51,000 parameters).
Fig. 5: Training biophysically detailed networks to solve computer vision tasks.
a, Biophysical network consisting of 28 × 28 input neurons, 64 morphologically detailed hidden neurons, and 10 output neurons. b, Voltages (left and middle) and softmax probabilities of the output neurons (right) measured at the somata of neurons in the trained network in response to an image labeled as ‘0’. Red color in the output layer indicates the image label. Scale bars, 2 ms and 80 mV. c, Histograms of test-set accuracy of 50 linear networks (gray), 50 multilayer perceptrons with 64 hidden neurons and ReLU activations (black) and the biophysical network (blue). d, Histogram of parameters before (black) and after (green) training. gsyn1, gsyn2, synaptic conductances in the first and second layer, respectively. e, Test-set accuracy for trained network when subsets of parameters are reset to their initial value. Blue line indicates the fully trained network. f, Clean image (top, left) and adversarial image perturbation (top, right), as well as corresponding voltage traces of the output neurons (bottom). g, Accuracy across 128 test-set examples, as a function of the norm of the adversarial image perturbation. Biophys., biophysical.
We simulated the network for 10 ms, as this was the time it took for the stimulus to propagate through the network. After training with stochastic gradient descent, upon being stimulated with a ‘0’ digit, the softmax of the voltages of the output neurons indicated a high probability for the digit ‘0’ (Fig. 5b). The network achieved an accuracy of 94.2% on a held-out test dataset, higher than that of a linear classifier, demonstrating that the biophysical network used its nonlinearities to improve classification performance. The biophysical network, however, performed slightly worse than a multilayer perceptron with ReLU nonlinearities, suggesting either that the spike/no-spike nonlinearities are more difficult to train than ReLU nonlinearities, or that the (binary) spike/no-spike representations have lower bandwidth than graded ReLU activations (Fig. 5c). After training, the biophysical network also developed interpretable hidden-layer tuning (Extended Data Fig. 9).
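A sketch of the classification objective is shown below, assuming the somatic voltages of the ten output neurons are reduced to one logit per class (how exactly the 10-ms voltage traces are reduced, for example by taking the peak or the final value, is an assumption here):

```python
import jax.numpy as jnp
import optax

def classification_loss(output_voltages, labels):
    """Cross-entropy on the softmax of the output neurons' somatic voltages.
    `output_voltages` has shape (batch, 10), for example one peak somatic
    voltage per output neuron; `labels` are integer digit classes."""
    return jnp.mean(
        optax.softmax_cross_entropy_with_integer_labels(output_voltages, labels)
    )

def accuracy(output_voltages, labels):
    return jnp.mean(jnp.argmax(output_voltages, axis=-1) == labels)
```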
How do the learned parameters of the biophysical network contribute to its ability to classify MNIST digits? Surprisingly, we found that the ranges of the trained synaptic parameters were roughly similar to those of the untrained network and that the membrane channel conductances remained roughly centered around their initial values (Fig. 5d). This does not mean that the learned values of these parameters do not contribute to the learned network dynamics: While the ranges of the parameters did not change substantially, individual parameters could change considerably (Extended Data Fig. 10). Furthermore, resetting subsets of parameters to their initial values reduced classification performance, sometimes to chance-level accuracy (10%; Fig. 5e). This indicates that biophysical simulations built purely from the aggregate statistics of measured parameter values may not be sufficient for the mod