# Conditional neural process


## Revision as of 18:40, 20 November 2018


## Introduction

To train a model effectively, deep neural networks typically require large datasets. One approach to this data-efficiency problem is to learn in two phases: the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task, but does so using only a small number of data points by exploiting the domain-wide statistics already learned. Another approach is to take a probabilistic stance and specify a distribution over functions (a stochastic process); Gaussian Processes (GPs) are a commonly used example. Such Bayesian methods can be computationally expensive, however: exact GP inference scales cubically with the number of data points.

The authors of the paper propose a family of models that represent solutions to the supervised problem, and an end-to-end training approach to learning them that combines neural networks with features reminiscent of Gaussian Processes. They call this family of models Conditional Neural Processes (CNPs). CNPs can make accurate predictions after observing only a few data points, while also having the capacity to scale to complex functions and large datasets.

## Model

Consider a data set [math] \{(x_i, y_i)\} [/math] with evaluations [math]y_i = f(x_i) [/math] for some unknown function [math]f : X \to Y[/math]. Assume [math]g[/math] is an approximating function of [math]f[/math]. The aim is to minimize the loss between [math]f[/math] and [math]g[/math] on the entire space [math]X[/math]. In practice, the routine is evaluated on a finite set of observations.

Let the set of observations be [math] O = \{(x_i, y_i)\}_{i = 0}^{n-1} \subset X \times Y[/math], and the set of targets be [math] T = \{x_i\}_{i = n}^{n + m - 1} \subset X[/math].

Let P be a probability distribution over functions [math] f : X \to Y[/math], formally known as a stochastic process. P then defines a joint distribution over the random variables [math] \{f(x_i)\}_{i = 0}^{n + m - 1}[/math]. Our task is to predict the output values [math]f(x_i)[/math] for [math] x_i \in T[/math] given [math] O[/math], that is, to model the conditional distribution [math] P(f(T) \mid O, T)[/math].

## Conditional Neural Process

Conditional Neural Process models directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. CNPs parametrize distributions over [math]f(T)[/math] given a distributed representation of [math]O[/math] of fixed dimensionality. In doing so, the mathematical guarantees associated with stochastic processes are traded off for functional flexibility and scalability.

A CNP is a conditional stochastic process [math]Q_\theta[/math] that defines distributions over [math]f(x_i)[/math] for [math]x_i \in T[/math]. As for stochastic processes, we assume [math]Q_\theta[/math] is invariant to permutations of [math]O[/math] and [math]T[/math]; in this work, permutation invariance with respect to [math]T[/math] is enforced by assuming a factored structure: [math]Q_\theta(f(T) \mid O, T) = \prod_{x \in T} Q_\theta(f(x) \mid O, x)[/math].
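The factored form can be illustrated with a short sketch. It assumes, as in the regression experiments, that each per-target factor Q_θ(f(x) | O, x) is a univariate Gaussian whose mean and variance come from the model; `toy_predict` is a hypothetical stand-in for a CNP that has already been conditioned on some O. Because the joint log-likelihood is a sum of per-target terms, reordering T does not change it.

```python
import math


def gaussian_log_prob(y, mean, var):
    """Log density of y under a univariate Gaussian N(mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mean) ** 2 / var)


def factored_log_likelihood(targets, predict):
    """log Q(f(T) | O, T) = sum over x in T of log Q(f(x) | O, x).

    targets: list of (x, y) pairs; predict(x) returns the (mean, var)
    of the Gaussian factor for that single target (assumed output family).
    """
    return sum(gaussian_log_prob(y, *predict(x)) for x, y in targets)


def toy_predict(x):
    # Hypothetical stand-in for a CNP already conditioned on some O.
    return 2.0 * x, 0.25


T = [(0.0, 0.1), (1.0, 1.8), (2.0, 4.3)]
ll = factored_log_likelihood(T, toy_predict)
ll_reordered = factored_log_likelihood(list(reversed(T)), toy_predict)
print(ll, ll_reordered)  # equal up to floating-point rounding
```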

In detail, we use the following architecture:

[math]r_i = h_\theta(x_i, y_i)[/math] for any [math](x_i, y_i) \in O[/math], where [math]h_\theta : X \times Y \to \mathbb{R} ^ d[/math]

[math]r = r_0 * r_1 * \dots * r_{n-1}[/math], where [math]*[/math] is a commutative operation that takes elements of [math]\mathbb{R}^d[/math] and maps them into a single element of [math]\mathbb{R}^d[/math]

[math]\Phi_i = g_\theta(x_i, r)[/math] for any [math]x_i \in T[/math], where [math]g_\theta : X \times \mathbb{R}^d \to \mathbb{R}^e[/math] and [math]\Phi_i[/math] are the parameters of [math]Q_\theta(f(x_i) \mid O, x_i)[/math]
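The three steps above (encode each observation, aggregate, decode per target) can be sketched in plain Python. This is a minimal, untrained illustration under several assumptions not stated in the text: h_θ and g_θ are tiny two-layer MLPs with random weights, the commutative operation is the element-wise mean (the choice used in the regression experiments), Φ_i is a Gaussian mean and variance, and all names and sizes are hypothetical. The check at the end shows that permuting O leaves the predictions unchanged.

```python
import math
import random

rng = random.Random(0)
D = 8  # representation size d (assumed)


def dense(n_in, n_out):
    """A dense layer with fixed random weights (untrained, for illustration)."""
    W = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return lambda v: [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, x) for x in v]


enc1, enc2 = dense(2, 16), dense(16, D)       # h_theta: X x Y -> R^d
dec1, dec2 = dense(1 + D, 16), dense(16, 2)   # g_theta: X x R^d -> R^2


def h(x, y):
    return enc2(relu(enc1([x, y])))


def aggregate(reps):
    """Commutative aggregation: element-wise mean of the r_i."""
    return [sum(col) / len(reps) for col in zip(*reps)]


def g(x, r):
    mean, raw = dec2(relu(dec1([x] + r)))
    # Numerically stable softplus keeps the variance strictly positive.
    var = max(raw, 0.0) + math.log1p(math.exp(-abs(raw))) + 1e-3
    return mean, var


def cnp_predict(observations, targets):
    r = aggregate([h(x, y) for x, y in observations])  # encode, then aggregate
    return [g(x, r) for x in targets]                  # Phi_i for each target


O = [(-1.0, 0.5), (0.0, 1.0), (1.0, 0.2)]
T = [-0.5, 0.5, 2.0]
preds = cnp_predict(O, T)
preds_permuted = cnp_predict(list(reversed(O)), T)
```

Each observation is encoded once and each target decoded once, which is where the linear cost in n + m comes from.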

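Note that this architecture ensures permutation invariance and [math]O(n + m)[/math] scaling for conditional prediction: each of the [math]n[/math] observations is encoded once and each of the [math]m[/math] targets is decoded once. Moreover, since [math]r_0 * r_1 * \dots * r_{n-1}[/math] can be updated in constant time from [math]r_0 * \dots * r_{n-2}[/math] when a new observation arrives, this architecture supports streaming observations with minimal overhead.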
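The streaming property can be illustrated with a running mean. This sketch assumes the mean is used as the commutative operation (with a plain sum, the update would simply add r_i). Each new observation updates the aggregate in time independent of how many observations came before (O(d) for a d-dimensional representation), and the result matches the batch mean. The class name is hypothetical.

```python
class StreamingAggregator:
    """Keeps the mean of the encoded observations r_i as they arrive."""

    def __init__(self, dim):
        self.count = 0
        self.r = [0.0] * dim

    def add(self, r_i):
        # Incremental mean update: r <- r + (r_i - r) / count.
        # The cost does not depend on the number of earlier observations.
        self.count += 1
        self.r = [m + (v - m) / self.count for m, v in zip(self.r, r_i)]
        return self.r


reps = [[1.0, 2.0], [3.0, 0.0], [5.0, 4.0]]
agg = StreamingAggregator(dim=2)
for r_i in reps:
    agg.add(r_i)
batch_mean = [sum(col) / len(reps) for col in zip(*reps)]
print(agg.r, batch_mean)  # [3.0, 2.0] [3.0, 2.0]
```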

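We train [math]Q_\theta[/math] by asking it to predict [math]O[/math] conditioned on a randomly chosen subset of [math]O[/math]. This gives the model a signal of the uncertainty over the space X inherent in the distribution P given a set of observations; the targets on which [math]Q_\theta[/math] is scored therefore include both observed and unobserved values. More precisely, at each step we sample a function [math]f \sim P[/math] with observations [math]O = \{(x_i, y_i)\}_{i=0}^{n-1}[/math], sample a subset size [math]N[/math] uniformly at random, condition on the subset [math]O_N \subset O[/math] of the first [math]N[/math] observations, and minimize the negative conditional log-likelihood

[math]\mathcal{L}(\theta) = -\mathbb{E}_{f \sim P}\left[\mathbb{E}_N\left[\log Q_\theta\left(\{y_i\}_{i=0}^{n-1} \mid O_N, \{x_i\}_{i=0}^{n-1}\right)\right]\right][/math]

In practice, we take Monte Carlo estimates of the gradient of this loss by sampling [math]f[/math] and [math]N[/math].

This approach shifts the burden of imposing prior knowledge from an analytic prior to empirical data. This has the advantage of liberating a practitioner from having to specify an analytic form for the prior, which is ultimately intended to summarize their empirical experience. Still, we emphasize that the [math]Q_\theta[/math] are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that.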
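One Monte Carlo sample of the training objective (the negative log-likelihood of all of O under Q_θ, conditioned on a random subset of O) can be sketched as follows. The sketch assumes Gaussian factors for Q_θ, a hypothetical `predict(context, xs)` callable standing in for the CNP, and that O arrives in random order, so its first N elements form a random subset. The sampling range for N (at least one observation, at most n − 1) is an illustrative assumption.

```python
import math
import random


def gaussian_log_prob(y, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mean) ** 2 / var)


def episode_loss(observations, predict, rng):
    """Single-sample estimate of the negative conditional log-likelihood.

    observations: the n pairs (x_i, y_i) observed for one sampled f, in random order.
    predict(context, xs): returns a (mean, var) pair for every x in xs.
    """
    n = len(observations)
    N = rng.randint(1, n - 1)            # size of the conditioning subset
    context = observations[:N]           # O_N: the first N observations
    xs = [x for x, _ in observations]    # score on all of O, observed or not
    preds = predict(context, xs)
    return -sum(gaussian_log_prob(y, m, v)
                for (_, y), (m, v) in zip(observations, preds))


def context_mean_predictor(context, xs):
    # Hypothetical baseline: predict the context mean with unit variance.
    mu = sum(y for _, y in context) / len(context)
    return [(mu, 1.0) for _ in xs]


rng = random.Random(0)
O = [(0.1 * i, 0.0) for i in range(5)]
loss = episode_loss(O, context_mean_predictor, rng)
```

Averaging such losses over many sampled functions and subset sizes gives the Monte Carlo gradient estimate described above.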

In summary,

1. A CNP is a conditional distribution over functions trained to model the empirical conditional distributions of functions [math]f \sim P[/math].

2. A CNP is permutation invariant in [math]O[/math] and [math]T[/math].

3. A CNP is scalable, achieving a running time complexity of [math]O(n + m)[/math] for making [math]m[/math] predictions with [math]n[/math] observations.

## Experimental Result I: Function Regression

Their first example is the classical 1D regression task, a common baseline for GPs. They generated two different datasets consisting of functions sampled from a GP with an exponential kernel. In the first dataset they used a kernel with fixed parameters; in the second dataset the function switched, at some random point on the real line, between two functions each sampled with different kernel parameters. At every training step they sampled a curve from the GP, selected a subset of n points as observations, and a subset of t points as target points. The observed points are encoded by a three-layer MLP encoder h with a 128-dimensional output representation. The representations are aggregated into a single representation [math]r = \frac{1}{n} \sum_i r_i[/math], which is concatenated to [math]x_t[/math] and passed to a decoder g consisting of a five-layer MLP. The decoder outputs the mean and variance of a Gaussian distribution over the target output [math]y_t[/math].
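The per-step data generation can be sketched as below. This is an illustration under assumptions. A squared-exponential kernel with hypothetical hyperparameters stands in for the text's "exponential kernel". Inputs are drawn uniformly from [−2, 2], a small diagonal jitter keeps the Cholesky factorization stable, and the observations are taken as a subset of the targets. Only the fixed-kernel dataset is shown.

```python
import math
import random


def sq_exp_kernel(a, b, length_scale=0.4, scale=1.0):
    return scale ** 2 * math.exp(-0.5 * ((a - b) / length_scale) ** 2)


def cholesky(K):
    """Lower-triangular L with K = L L^T (K must be positive definite)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def sample_gp_curve(xs, rng, jitter=1e-4):
    """Draw f(xs) from a zero-mean GP: f = L z with z ~ N(0, I)."""
    K = [[sq_exp_kernel(a, b) + (jitter if i == j else 0.0)
          for j, b in enumerate(xs)] for i, a in enumerate(xs)]
    L = cholesky(K)
    z = [rng.gauss(0.0, 1.0) for _ in xs]
    return [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(len(xs))]


def make_regression_task(rng, num_points=50, max_context=10):
    xs = [rng.uniform(-2.0, 2.0) for _ in range(num_points)]
    pairs = list(zip(xs, sample_gp_curve(xs, rng)))
    rng.shuffle(pairs)
    n = rng.randint(3, max_context)          # number of observations
    t = rng.randint(n + 1, num_points)       # number of target points
    return pairs[:n], pairs[:t]              # observations are a subset of the targets


context, targets = make_regression_task(random.Random(1))
```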

Two examples of the regression results obtained for each of the datasets are shown in the following figure.

They compared the model to the predictions generated by a GP with the correct hyperparameters, which constitutes an upper bound on the CNP's performance. Although the GP's predictions are smoother than the CNP's for both the mean and the variance, the CNP learns to regress from a few context points for both the fixed and the switching kernels, and it learns to estimate its own uncertainty given the observations accurately. As the number of context points grows, the accuracy of the model improves and its estimated uncertainty decreases. The model also performs similarly well on the switching-kernel task. This type of regression task is not trivial for GPs, whereas for the CNP only the dataset used for training has to change.
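For reference, the GP baseline's predictive mean and variance at a query point can be computed in closed form. This is a minimal sketch. It assumes a squared-exponential kernel and Gaussian observation noise with hypothetical hyperparameters, standing in for the "correct" ones used in the comparison. It also illustrates that the GP's predictive variance cannot increase as context points are added. Exact inference costs O(n³) in the number of context points because of the Cholesky factorization.

```python
import math


def kernel(a, b, length_scale=0.4, scale=1.0):
    return scale ** 2 * math.exp(-0.5 * ((a - b) / length_scale) ** 2)


def cholesky(K):
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def solve_lower(L, b):
    """Forward substitution: solve L x = b for lower-triangular L."""
    x = []
    for i, row in enumerate(L):
        x.append((b[i] - sum(row[k] * x[k] for k in range(i))) / row[i])
    return x


def gp_posterior(context, x_star, noise=1e-2):
    """Predictive mean and variance of f(x_star) given noisy context pairs."""
    xs = [x for x, _ in context]
    ys = [y for _, y in context]
    K = [[kernel(a, b) + (noise if i == j else 0.0)
          for j, b in enumerate(xs)] for i, a in enumerate(xs)]
    L = cholesky(K)                                        # O(n^3)
    v = solve_lower(L, [kernel(x_star, x) for x in xs])    # L^-1 k*
    w = solve_lower(L, ys)                                 # L^-1 y
    mean = sum(a * b for a, b in zip(v, w))                # k*^T K^-1 y
    var = kernel(x_star, x_star) - sum(a * a for a in v)   # k** - k*^T K^-1 k*
    return mean, var


one = gp_posterior([(0.0, 1.0)], x_star=0.4)
two = gp_posterior([(0.0, 1.0), (0.5, 0.3)], x_star=0.4)
```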

## Experimental Result II: Image Completion for Digits

They also tested CNP on the MNIST dataset, using the test set to evaluate its performance. As shown in the above figure, the model learns to make good predictions of the underlying digit even from a small number of context points. Crucially, when conditioned only on one non-informative context point, the model's prediction corresponds to the average over all MNIST digits. As the number of context points increases, the predictions become more similar to the underlying ground truth. This demonstrates the model's capacity to extract dataset-specific prior knowledge. Even with a complete set of observations the model does not achieve pixel-perfect reconstruction, because of the bottleneck at the representation level. Since this implementation of CNP returns factored outputs, the best prediction it can produce given limited context information is to average over all possible predictions that agree with the context. An alternative is to add latent variables to the model, so that they can be sampled conditioned on the context to produce predictions with high probability under the data distribution.
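Image completion can be treated as a regression problem over pixel coordinates. The sketch below assumes grayscale images stored as nested lists with intensities in 0..255, coordinates rescaled to [0, 1]² and intensities to [0, 1]; the function names are hypothetical. Every pixel is a target, and a random subset of pixels serves as context.

```python
import random


def image_to_pairs(image):
    """Flatten an H x W grayscale image into ((row, col), intensity) pairs.

    Coordinates are rescaled to [0, 1]^2 and intensities from 0..255 to [0, 1].
    """
    H, W = len(image), len(image[0])
    return [((r / (H - 1), c / (W - 1)), image[r][c] / 255.0)
            for r in range(H) for c in range(W)]


def random_context(pairs, num_context, rng):
    """Pick num_context random pixels as observations; all pixels remain targets."""
    return rng.sample(pairs, num_context), pairs


image = [[0, 128, 255],
         [0, 255, 0],
         [255, 0, 128]]
pairs = image_to_pairs(image)
context, targets = random_context(pairs, num_context=2, rng=random.Random(0))
```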

An important aspect of the model is its ability to estimate
the uncertainty of the prediction. As shown in the bottom
row of the above figure, as they added more observations, the variance
shifts from being almost uniformly spread over the digit
positions to being localized around areas that are specific
to the underlying digit, specifically its edges. Being able to
model the uncertainty given some context can be helpful for
many tasks. One example is active exploration, where the
model has a choice over where to observe.
They tested this by
comparing the predictions of CNP when the observations
are chosen according to uncertainty, versus random pixels. This method is a very simple way of doing active
exploration, but it already produces better prediction results
than selecting the conditioning points at random.
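The uncertainty-driven selection can be sketched as a greedy loop: at each step, observe the not-yet-observed candidate with the highest predicted variance. `toy_variance` is a hypothetical stand-in for the CNP's predicted variance (here simply the distance to the nearest observed point), and 1D candidates are used for brevity instead of pixel coordinates.

```python
def select_context(candidates, predict_variance, num_steps):
    """Greedily choose observations where the model is most uncertain."""
    observed = []
    for _ in range(num_steps):
        remaining = [x for x in candidates if x not in observed]
        # With a CNP, these variances would come from a forward pass conditioned on `observed`.
        best = max(remaining, key=lambda x: predict_variance(observed, x))
        observed.append(best)
    return observed


def toy_variance(observed, x):
    # Hypothetical stand-in: uncertainty grows with distance to the nearest observation.
    return min((abs(x - o) for o in observed), default=1.0)


chosen = select_context([0.0, 0.25, 0.5, 0.75, 1.0], toy_variance, num_steps=3)
print(chosen)  # [0.0, 1.0, 0.5]
```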

## Experimental Result III: Image Completion for Faces

They also applied CNP to CelebA, a dataset of images of
celebrity faces, and reported performance obtained on the
test set.

As shown in the above figure, the model is able to capture the complex shapes and colours of this dataset, with predictions conditioned on less than 10% of the pixels already being close to the ground truth. As before, given few context points the model averages over all possible faces, but as the number of context pairs increases the predictions capture image-specific details like face orientation and facial expression. Furthermore, as the number of context points increases the variance shifts towards the edges in the image.

An important aspect of CNPs, demonstrated in Figure 5, is their flexibility not only in the number of observations and targets they receive but also with regard to their input values. It is interesting to compare this property to GPs on one hand, and to trained generative models (van den Oord et al., 2016; Gregor et al., 2015) on the other hand. The first type of flexibility can be seen when conditioning on subsets that the model has not encountered during training. Consider conditioning the model on one half of the image, for example. This forces the model to predict pixel values not only according to some stationary smoothness property of the images, but also according to global spatial properties, e.g. symmetry and the relative location of different parts of faces. As seen in the first row of the figure, CNPs are able to capture those properties. A GP with a stationary kernel cannot capture this, and in the absence of observations would revert to its mean (the mean itself can be non-stationary, but usually this would not be enough to capture the interesting properties).

In addition, the model is flexible with regard to the target input values. This means, e.g., that the model can be queried at resolutions it has not seen during training. They took a model that had only been trained using pixel coordinates of a specific resolution and predicted, at test time, subpixel values for targets between the original coordinates. As shown in Figure 5, with one forward pass the model can be queried at different resolutions. While GPs also exhibit this type of flexibility, trained generative models do not: they can only predict values for the pixel coordinates on which they were trained. In this sense, CNPs capture the best of both worlds: they are flexible with regard to the conditioning and prediction task, and they have the capacity to extract domain knowledge from a training set.

They compared CNPs quantitatively to two related models: kNNs and GPs. As shown in Table 4.2.3, CNPs outperform both when the number of context points is small (empirically, when half of the image or less is provided as context). When the majority of the image is given as context, exact methods like GPs and kNN perform better. The table also shows that the order in which the context points are provided matters less for CNPs: providing the context points in order from top to bottom still results in good performance. Both insights point to the fact that CNPs learn a data-specific ‘prior’ that will generate good samples even when the number of context points is very small.
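Because targets are just coordinates, querying at a new resolution or conditioning on half an image only changes which coordinates are passed in. This is a sketch under an assumed convention of (row, column) coordinates rescaled to [0, 1]². The function names are hypothetical, and the trained model that would consume these coordinates is not shown. A 5 × 5 query grid contains the 3 × 3 grid plus the subpixel points between its coordinates.

```python
def target_grid(resolution):
    """All (row, col) coordinates of a resolution x resolution grid in [0, 1]^2."""
    return [(r / (resolution - 1), c / (resolution - 1))
            for r in range(resolution) for c in range(resolution)]


def top_half_context(pairs):
    """Keep only the observed pixels whose row coordinate lies in the top half."""
    return [((row, col), y) for (row, col), y in pairs if row < 0.5]


coarse, fine = target_grid(3), target_grid(5)
# Subpixel targets that a model trained only on the 3 x 3 grid never saw.
new_points = sorted(set(fine) - set(coarse))
pairs = [((0.0, 0.0), 0.2), ((0.0, 1.0), 0.4), ((1.0, 0.0), 0.6), ((1.0, 1.0), 0.8)]
context = top_half_context(pairs)
```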

## Experimental Result IV: Classification

Finally, they applied the model to one-shot classification using the Omniglot dataset. This dataset consists of 1,623 classes of characters from 50 different alphabets. Each class has only 20 examples, which makes the dataset particularly suitable for few-shot learning algorithms. They used 1,200 randomly selected classes as their training set and the remainder as their test set. They also augmented the data, which includes cropping the images from 32 × 32 to 28 × 28, applying small random translations and rotations to the inputs, and increasing the number of classes by rotating every character by 90 degrees and defining that to be a new class. They generated the labels for an N-way classification task by choosing N random classes at each training step and arbitrarily assigning the labels 0, ..., N − 1 to them.
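The episode construction can be sketched as follows. This is an illustration under assumptions. Each class is a list of images stored as nested lists, and rotations by 0, 90, 180 and 270 degrees each define a new class. One held-out query per class serves as the target, and all names are hypothetical. Labels 0, ..., N − 1 are assigned in the random order in which classes are drawn, so they carry no meaning across episodes.

```python
import random


def rot90(image):
    """Rotate a nested-list image 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def augment_with_rotations(classes):
    """Treat each rotation of a class (0, 90, 180, 270 degrees) as a new class."""
    augmented = {}
    for name, images in classes.items():
        for k in range(4):
            augmented[(name, k)] = images
            images = [rot90(im) for im in images]
    return augmented


def sample_n_way_task(classes, n_way, shots, rng):
    """Sample an N-way episode with arbitrary labels 0..N-1."""
    chosen = rng.sample(sorted(classes), n_way)  # random order -> arbitrary labels
    context, queries = [], []
    for label, name in enumerate(chosen):
        examples = rng.sample(classes[name], shots + 1)
        context += [(im, label) for im in examples[:shots]]
        queries.append((examples[shots], label))
    return context, queries


base = {c: [[[c, i], [0, 1]] for i in range(3)] for c in "abc"}
classes = augment_with_rotations(base)
context, queries = sample_n_way_task(classes, n_way=5, shots=1, rng=random.Random(0))
```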

Given that the input points are images, they modified the architecture of the encoder h to include convolutional layers, as mentioned in section 2 of the paper. In addition, they aggregated only over inputs of the same class, using the information provided by the input label. The aggregated class-specific representations are then concatenated to form the final representation. Because both the size of the class-specific representations and the number of classes are constant, the size of the final representation is still constant, and thus the O(n + m) runtime still holds. The results of the classification are summarized in the following table. CNPs achieve higher accuracy than models that are significantly more complex (like MANN). While CNPs do not beat the state of the art for one-shot classification, their accuracy values are comparable. Crucially, they reached those values using a significantly simpler architecture (three convolutional layers for the encoder and a three-layer MLP for the decoder) and with a lower runtime of O(n + m) at test time, as opposed to O(nm).
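The class-wise aggregation can be sketched as below. It assumes the per-class operation is an element-wise mean and that the encoder outputs (the convolutional part is not shown) are already available as lists; the names are hypothetical. The final representation always has N·d entries, however many context points there are.

```python
def classwise_representation(encoded_context, n_way, dim):
    """Average the encodings r_i per label, then concatenate in label order.

    encoded_context: list of (r_i, label) with label in 0..n_way-1.
    Classes without context points contribute zeros (an assumption).
    """
    sums = [[0.0] * dim for _ in range(n_way)]
    counts = [0] * n_way
    for r_i, label in encoded_context:
        counts[label] += 1
        sums[label] = [s + v for s, v in zip(sums[label], r_i)]
    final = []
    for s, c in zip(sums, counts):
        final += [v / c for v in s] if c else [0.0] * dim
    return final


rep = classwise_representation([([1.0, 2.0], 0), ([3.0, 4.0], 0), ([5.0, 6.0], 1)],
                               n_way=2, dim=2)
print(rep)  # [2.0, 3.0, 5.0, 6.0]
```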

## Conclusion

In this paper they introduced Conditional Neural Processes, a model that is both flexible at test time and has the capacity to extract prior knowledge from training data.

They demonstrated its ability to perform a variety of tasks including regression, classification and image completion. They compared CNPs to Gaussian Processes on one hand, and deep learning methods on the other, and also discussed the relation to meta-learning and few-shot learning. It is important to note that the specific CNP implementations described here are just simple proofs-of-concept and can be substantially extended, e.g. by including more elaborate architectures in line with modern deep learning advances. To summarize, this work can be seen as a step towards learning high-level abstractions, one of the grand challenges of contemporary machine learning. Functions learned by most conventional deep learning models are tied to a specific, constrained statistical context at any stage of training. A trained CNP is more general, in that it encapsulates the high-level statistics of a family of functions. As such it constitutes a high-level abstraction that can be reused for multiple tasks. In future work they plan to explore how far these models can help in tackling the many key machine learning problems that seem to hinge on abstraction, such as transfer learning, meta-learning, and data efficiency.