---
date: '2024-11-11'
description: tidbits
id: PyTorch
modified: 2026-06-05 15:08:20 GMT-04:00
tags:
  - ml
  - framework
title: PyTorch
created: '2024-11-11'
published: '2024-11-11'
pageLayout: default
slug: thoughts/PyTorch
permalink: https://aarnphm.xyz/thoughts/PyTorch.md
generator:
  quartz: v4.6.0
  hostedProvider: Cloudflare
  baseUrl: aarnphm.xyz
full: https://aarnphm.xyz/llms-full.txt
---
see also: [unstable docs](https://pytorch.org/docs/main/)

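```python title="qk_score.py"
import torch

# Softmax dilution: the same maximum logit receives less attention weight
# as the number of competing keys grows.
torch.manual_seed(0)

qk_scores_short = torch.randn(2048)    # 2K-token context
qk_scores_long = torch.randn(128000)   # 128K-token context

# plant the same maximum score at position 0 of both sequences
max_val = torch.max(qk_scores_short.max(), qk_scores_long.max())
qk_scores_short[0] = max_val
qk_scores_long[0] = max_val

# attention weight assigned to position 0 in each context
print(qk_scores_short.softmax(0)[0], qk_scores_long.softmax(0)[0])
```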
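
A rough model of what the snippet prints. It assumes, as `torch.randn` gives, that every other score $s_j$ is i.i.d. $\mathcal{N}(0, 1)$. Write $m$ for the planted maximum and $N$ for the sequence length:

$$
p_0 = \frac{e^{m}}{e^{m} + \sum_{j=1}^{N-1} e^{s_j}} \approx \frac{e^{m}}{e^{m} + (N - 1)\sqrt{e}}, \qquad \text{since } \mathbb{E}\left[e^{s_j}\right] = e^{1/2}
$$

With $m$ fixed, the denominator grows roughly linearly in $N$. Moving from $N = 2048$ to $N = 128000$ therefore shrinks $p_0$ by a factor close to $128000 / 2048 = 62.5$ whenever $e^m$ is small next to the sum. A logit that dominates a short context can end up nearly ignored in a long one.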

## `MultiMarginLoss`

Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input $x$
(a 2D mini-batch `Tensor`) and target $y$ (a 1D tensor of target class indices, $0 \le y \le x.\text{size}(1) - 1$).

For each mini-batch sample, the loss in terms of the 1D input $x$ and target $y$ is:

$$
\text{loss}(x, y) = \frac{\sum_{i} \max(0, \text{margin} - x[y] + x[i])^p}{x.\text{size}(0)},
\qquad \text{where } i \in \{0, \ldots, x.\text{size}(0) - 1\} \text{ and } i \neq y
$$
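
The formula can be sanity-checked with a minimal hand-rolled version compared against `torch.nn.MultiMarginLoss`. This sketch assumes the defaults `weight=None` and `reduction='mean'`, so per-sample losses are averaged over the batch. `multi_margin_loss` is a hypothetical helper name, not a PyTorch function. PyTorch itself only accepts $p \in \{1, 2\}$.

```python
import torch
import torch.nn as nn


def multi_margin_loss(x: torch.Tensor, y: torch.Tensor, p: int = 1, margin: float = 1.0) -> torch.Tensor:
    """Hand-rolled MultiMarginLoss (hypothetical helper; no class weights, reduction='mean')."""
    n, c = x.shape
    # x[y] for every sample, kept as a column so it broadcasts over classes: (N, 1)
    x_y = x.gather(1, y.unsqueeze(1))
    # hinge term max(0, margin - x[y] + x[i])^p for every class i: (N, C)
    hinge = (margin - x_y + x).clamp(min=0).pow(p)
    # drop the i == y term from the sum
    is_target = torch.zeros_like(hinge, dtype=torch.bool)
    is_target[torch.arange(n), y] = True
    hinge = hinge.masked_fill(is_target, 0.0)
    # divide by x.size(0) of the 1D sample, i.e. the number of classes C
    per_sample = hinge.sum(dim=1) / c
    return per_sample.mean()


torch.manual_seed(0)
x = torch.randn(4, 5)             # 4 samples, 5 classes
y = torch.tensor([0, 3, 1, 4])    # target class index per sample
for p in (1, 2):
    ours = multi_margin_loss(x, y, p=p)
    ref = nn.MultiMarginLoss(p=p, margin=1.0)(x, y)
    print(f"p={p}: ours={ours.item():.6f}  torch={ref.item():.6f}")
```

Masking the $i = y$ entry, rather than subtracting it afterwards, keeps the code a direct transcription of the sum over $i \neq y$.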

## `SGD`

[[thoughts/Nesterov momentum]] is based on the formula from [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf) (Sutskever et al., 2013).
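
The pseudocode that follows maps almost line for line onto plain tensor arithmetic. Below is a minimal sketch for a single parameter tensor. It is not PyTorch's actual implementation, which updates parameters in place. `sgd_step` and its argument names are hypothetical. `buf` holds $b_{t-1}$, with `None` marking $t = 1$.

```python
import torch


def sgd_step(param, grad, buf, lr, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False, maximize=False):
    """One SGD update on a single tensor (hypothetical helper mirroring the pseudocode).

    `buf` is the momentum buffer b_{t-1}, or None on the first step (t = 1).
    Returns the updated parameter and buffer as new tensors.
    """
    g = grad                                   # g_t <- grad of f_t at theta_{t-1}
    if weight_decay != 0:
        g = g + weight_decay * param           # g_t <- g_t + lambda * theta_{t-1}
    if momentum != 0:
        if buf is None:
            buf = g.clone()                    # t = 1: b_t <- g_t
        else:
            buf = momentum * buf + (1 - dampening) * g
        g = g + momentum * buf if nesterov else buf
    # maximize flips the sign of the final step, as written in the pseudocode
    return (param + lr * g if maximize else param - lr * g), buf


# usage: minimize f(w) = 0.5 * ||w - 3||^2, whose gradient is w - 3
w, buf = torch.zeros(2), None
for _ in range(100):
    w, buf = sgd_step(w, w - 3.0, buf, lr=0.1, momentum=0.9, nesterov=True)
print(w)  # close to 3.0 in both coordinates
```

As in the pseudocode, $\gamma$ scales the whole step rather than being folded into the momentum buffer. The PyTorch docs note that this differs subtly from the Sutskever et al. formulation, where the learning rate enters the velocity update.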

<div class="ps-root" data-inline-macros=""><span class="ps-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><annotation encoding="application/x-tex">"\\begin{algorithm}\n\\caption{SGD in PyTorch}\n\\begin{algorithmic}\n\\State \\textbf{input:} $\\gamma$ (lr), $\\theta_0$ (params), $f(\\theta)$ (objective), $\\lambda$ (weight decay),\n\\State $\\mu$ (momentum), $\\tau$ (dampening), nesterov, maximize\n\\For{$t = 1$ to $...$}\n    \\State $g_t \\gets \\nabla_\\theta f_t(\\theta_{t-1})$\n    \\If{$\\lambda \\neq 0$}\n        \\State $g_t \\gets g_t + \\lambda\\theta_{t-1}$\n    \\EndIf\n    \\If{$\\mu \\neq 0$}\n        \\If{$t > 1$}\n            \\State $b_t \\gets \\mu b_{t-1} + (1-\\tau)g_t$\n        \\Else\n            \\State $b_t \\gets g_t$\n        \\EndIf\n        \\If{$\\text{nesterov}$}\n            \\State $g_t \\gets g_t + \\mu b_t$\n        \\Else\n            \\State $g_t \\gets b_t$\n        \\EndIf\n    \\EndIf\n    \\If{$\\text{maximize}$}\n        \\State $\\theta_t \\gets \\theta_{t-1} + \\gamma g_t$\n    \\Else\n        \\State $\\theta_t \\gets \\theta_{t-1} - \\gamma g_t$\n    \\EndIf\n\\EndFor\n\\State \\textbf{return} $\\theta_t$\n\\end{algorithmic}\n\\end{algorithm}"</annotation></semantics></math></span>
<div class="ps-algorithm with-caption">
<p class="ps-line" style="text-indent:-0.6em;padding-left:0.6em;">
<span class="ps-keyword">Algorithm 3 </span>SGD in PyTorch</p>
<div class="ps-algorithmic with-linenum">
<div class="ps-block" style="margin-left:1.2em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:0em;">1:</span><span style="font-weight:bold;">input:</span> <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>γ</mi></mrow><annotation encoding="application/x-tex">\gamma</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.0556em;">γ</span></span></span></span> (lr), <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mn>0</mn></msub></mrow><annotation encoding="application/x-tex">\theta_0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> (params), <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>f</mi><mo stretchy="false">(</mo><mi>θ</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">f(\theta)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" 
style="margin-right:0.1076em;">f</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="mclose">)</span></span></span></span> (objective), <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>λ</mi></mrow><annotation encoding="application/x-tex">\lambda</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">λ</span></span></span></span> (weight decay),</p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:0em;">2:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">μ</span></span></span></span> (momentum), <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>τ</mi></mrow><annotation encoding="application/x-tex">\tau</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.1132em;">τ</span></span></span></span> (dampening), nesterov, maximize</p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:0em;">3:</span><span class="ps-keyword">for </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">t = 1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">1</span></span></span></span> to <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi></mrow><annotation encoding="application/x-tex">...</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.1056em;"></span><span class="mord">...</span></span></span></span><span class="ps-keyword"> do</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">4:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>g</mi><mi>t</mi></msub><mo>←</mo><msub><mi mathvariant="normal">∇</mi><mi>θ</mi></msub><msub><mi>f</mi><mi>t</mi></msub><mo stretchy="false">(</mo><msub><mi>θ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">g_t \gets \nabla_\theta f_t(\theta_{t-1})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord">∇</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.0278em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.1076em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">5:</span><span class="ps-keyword">if </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>λ</mi><mo mathvariant="normal">≠</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\lambda \neq 0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">λ</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel"><span class="mrel"><span class="mord vbox"><span class="thinbox"><span class="rlap"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="inner"><span class="mord"><span class="mrel"></span></span></span><span class="fix"></span></span></span></span></span><span class="mspace nobreak"></span><span class="mrel">=</span></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0</span></span></span></span><span class="ps-keyword"> then</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">6:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>g</mi><mi>t</mi></msub><mo>←</mo><msub><mi>g</mi><mi>t</mi></msub><mo>+</mo><mi>λ</mi><msub><mi>θ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow><annotation encoding="application/x-tex">g_t \gets g_t + \lambda\theta_{t-1}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.9028em;vertical-align:-0.2083em;"></span><span class="mord mathnormal">λ</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">7:</span><span class="ps-keyword">end if</span></p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">8:</span><span class="ps-keyword">if </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo mathvariant="normal">≠</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\mu \neq 0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel"><span class="mrel"><span class="mord vbox"><span class="thinbox"><span class="rlap"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="inner"><span class="mord"><span class="mrel"></span></span></span><span class="fix"></span></span></span></span></span><span class="mspace nobreak"></span><span class="mrel">=</span></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0</span></span></span></span><span class="ps-keyword"> then</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">9:</span><span class="ps-keyword">if </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi><mo>></mo><mn>1</mn></mrow><annotation encoding="application/x-tex">t > 1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6542em;vertical-align:-0.0391em;"></span><span class="mord mathnormal">t</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">1</span></span></span></span><span class="ps-keyword"> then</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-2.25em;">10:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>b</mi><mi>t</mi></msub><mo>←</mo><mi>μ</mi><msub><mi>b</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>+</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mi>τ</mi><mo stretchy="false">)</mo><msub><mi>g</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">b_t \gets \mu b_{t-1} + (1-\tau)g_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.9028em;vertical-align:-0.2083em;"></span><span class="mord mathnormal">μ</span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord 
mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.1132em;">τ</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">11:</span><span class="ps-keyword">else</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-2.25em;">12:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>b</mi><mi>t</mi></msub><mo>←</mo><msub><mi>g</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">b_t \gets g_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">13:</span><span class="ps-keyword">end if</span></p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">14:</span><span class="ps-keyword">if </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>nesterov</mtext></mrow><annotation encoding="application/x-tex">\text{nesterov}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord text"><span class="mord">nesterov</span></span></span></span></span><span class="ps-keyword"> then</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-2.25em;">15:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>g</mi><mi>t</mi></msub><mo>←</mo><msub><mi>g</mi><mi>t</mi></msub><mo>+</mo><mi>μ</mi><msub><mi>b</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">g_t \gets g_t + \mu b_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">μ</span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">16:</span><span class="ps-keyword">else</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-2.25em;">17:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>g</mi><mi>t</mi></msub><mo>←</mo><msub><mi>b</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">g_t \gets b_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">18:</span><span class="ps-keyword">end if</span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">19:</span><span class="ps-keyword">end if</span></p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">20:</span><span class="ps-keyword">if </span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>maximize</mtext></mrow><annotation encoding="application/x-tex">\text{maximize}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6679em;"></span><span class="mord text"><span class="mord">maximize</span></span></span></span></span><span class="ps-keyword"> then</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">21:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mi>t</mi></msub><mo>←</mo><msub><mi>θ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>+</mo><mi>γ</mi><msub><mi>g</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\theta_t \gets \theta_{t-1} + \gamma g_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.9028em;vertical-align:-0.2083em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.0556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">22:</span><span class="ps-keyword">else</span></p>
<div class="ps-block" style="margin-left:0.6em;">
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-1.5em;">23:</span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mi>t</mi></msub><mo>←</mo><msub><mi>θ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>−</mo><mi>γ</mi><msub><mi>g</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\theta_t \gets \theta_{t-1} - \gamma g_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.9028em;vertical-align:-0.2083em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.0556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0359em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:-0.75em;">24:</span><span class="ps-keyword">end if</span></p>
</div>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:0em;">25:</span><span class="ps-keyword">end for</span></p>
<p class="ps-line ps-code">
<span class="ps-linenum" style="left:0em;">26:</span><span style="font-weight:bold;">return</span> <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\theta_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0278em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></p>
</div>
</div>
</div>
</div>
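A minimal sketch of the pseudocode above as plain tensor ops, cross-checked against `torch.optim.SGD` on a toy quadratic $f(\theta) = \|\theta - c\|^2$. `sgd_step` and `compare_with_torch` are hypothetical helpers, not PyTorch APIs. One hedged caveat: recent versions of `torch/optim/sgd.py` appear to negate the gradient up front when `maximize=True`, rather than flipping the sign of the final update as the pseudocode does. The two can therefore disagree once weight decay is non-zero, so the check only exercises `maximize=False`.

```python
import torch


def sgd_step(
  params, grads, bufs, t, lr,
  momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, maximize=False,
):
  """Apply step t (1-indexed) of the SGD pseudocode in place to every tensor in `params`.

  `bufs` holds one momentum buffer b_t per parameter (None before the first step).
  """
  with torch.no_grad():
    for i, (theta, g) in enumerate(zip(params, grads)):
      if weight_decay != 0:
        g = g + weight_decay * theta  # g_t <- g_t + lambda * theta_{t-1}
      if momentum != 0:
        if t > 1:
          bufs[i] = momentum * bufs[i] + (1 - dampening) * g
        else:
          bufs[i] = g.clone()  # b_1 <- g_1, no dampening on the first step
        g = g + momentum * bufs[i] if nesterov else bufs[i]
      if maximize:
        theta.add_(lr * g)  # theta_t <- theta_{t-1} + lr * g_t
      else:
        theta.sub_(lr * g)  # theta_t <- theta_{t-1} - lr * g_t


def compare_with_torch(steps=5, **hp):
  """Run `steps` updates of both implementations on f(theta) = ||theta - c||^2."""
  torch.manual_seed(0)
  c = torch.randn(4)
  p_ref = torch.randn(4, requires_grad=True)
  p_ours = p_ref.detach().clone()
  opt = torch.optim.SGD([p_ref], **hp)
  bufs = [None]
  for t in range(1, steps + 1):
    opt.zero_grad()
    ((p_ref - c) ** 2).sum().backward()
    opt.step()
    grad = 2 * (p_ours - c)  # analytic gradient at theta_{t-1}
    sgd_step([p_ours], [grad], bufs, t, **hp)
  return torch.allclose(p_ref.detach(), p_ours, atol=1e-5)


print(compare_with_torch(lr=0.05, momentum=0.9, dampening=0.1, weight_decay=0.01))
```

With `nesterov=True`, PyTorch additionally requires a non-zero momentum and `dampening=0`, so the Nesterov check uses the default dampening.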

## [[thoughts/knowledge distillation]]

example on CIFAR-10

```python title="distill.py" path="thoughts/scripts/distill.py"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Use the current accelerator (e.g. CUDA or MPS) if one is available, otherwise fall back to the CPU.
# See https://pytorch.org/docs/stable/torch.html#accelerators
device = (
  torch.accelerator.current_accelerator().type
  if torch.accelerator.is_available()
  else 'cpu'
)
print(f'Using {device} device')

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(
  root='./data', train=True, download=True, transform=transforms_cifar
)
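test_dataset = datasets.CIFAR10(
  root='./data', train=False, download=True, transform=transforms_cifar
)

# Wrap the datasets in DataLoaders (batch size 128; only the training split is shuffled)
train_loader = torch.utils.data.DataLoader(
  train_dataset, batch_size=128, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
  test_dataset, batch_size=128, shuffle=False
)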


# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
  def __init__(self, num_classes=10):
    super(DeepNN, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 128, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(128, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(64, 32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
      nn.Linear(2048, 512),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(512, num_classes),
    )

  def forward(self, x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x


# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
  def __init__(self, num_classes=10):
    super(LightNN, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(16, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
      nn.Linear(1024, 256),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(256, num_classes),
    )

  def forward(self, x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x


def train(model, train_loader, epochs, learning_rate, device):
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  model.train()

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      # inputs: A collection of batch_size images
      # labels: A vector of dimensionality batch_size with integers denoting class of each image
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      outputs = model(inputs)

      # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
      # labels: The actual labels of the images. Vector of dimensionality batch_size
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(
      f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}'
    )


def test(model, test_loader, device):
  model.to(device)
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      outputs = model(inputs)
      _, predicted = torch.max(outputs.data, 1)

      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f'Test Accuracy: {accuracy:.2f}%')
  return accuracy


def train_knowledge_distillation(
  teacher,
  student,
  train_loader,
  epochs,
  learning_rate,
  T,
  soft_target_loss_weight,
  ce_loss_weight,
  device,
):
  ce_loss = nn.CrossEntropyLoss()
  optimizer = optim.Adam(student.parameters(), lr=learning_rate)

  teacher.eval()  # Teacher set to evaluation mode
  student.train()  # Student to train mode

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
      with torch.no_grad():
        teacher_logits = teacher(inputs)

      # Forward pass with the student model
      student_logits = student(inputs)

      # Soften both distributions with temperature T: teacher probabilities via softmax, student log-probabilities via log_softmax
      soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
      soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

      # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
      soft_targets_loss = (
        torch.sum(soft_targets * (soft_targets.log() - soft_prob))
        / soft_prob.size()[0]
        * (T**2)
      )

      # Calculate the true label loss
      label_loss = ce_loss(student_logits, labels)

      # Weighted sum of the two losses
      loss = (
        soft_target_loss_weight * soft_targets_loss
        + ce_loss_weight * label_loss
      )

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(
      f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}'
    )


if __name__ == '__main__':
  torch.manual_seed(42)
  nn_deep = DeepNN(num_classes=10).to(device)
  train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
  test_accuracy_deep = test(nn_deep, test_loader, device)

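  # Instantiate the lightweight network:
  torch.manual_seed(42)
  nn_light = LightNN(num_classes=10).to(device)

  # A second student with the same seed (hence identical initial weights), kept for the distillation run
  torch.manual_seed(42)
  new_nn_light = LightNN(num_classes=10).to(device)

  # Both first-layer norms should be equal, confirming the identical initialisation
  print(
    'Norm of 1st layer of nn_light:',
    torch.norm(nn_light.features[0].weight).item(),
  )
  print(
    'Norm of 1st layer of new_nn_light:',
    torch.norm(new_nn_light.features[0].weight).item(),
  )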
  train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
  test_accuracy_light_ce = test(nn_light, test_loader, device)

  print(f'Teacher accuracy: {test_accuracy_deep:.2f}%')
  print(f'Student accuracy: {test_accuracy_light_ce:.2f}%')

  # Apply `train_knowledge_distillation` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for the distillation loss.
  train_knowledge_distillation(
    teacher=nn_deep,
    student=new_nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    T=2,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
  )
  test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

  # Compare the student test accuracy with and without the teacher, after distillation
  print(f'Teacher accuracy: {test_accuracy_deep:.2f}%')
  print(f'Student accuracy without teacher: {test_accuracy_light_ce:.2f}%')
  print(f'Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%')

```
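The soft-target term in `train_knowledge_distillation` is a hand-written KL divergence $\mathrm{KL}(p_{\text{teacher}} \,\|\, p_{\text{student}})$, averaged over the batch and scaled by $T^2$. As a sanity check (a sketch on random logits, not part of the original script; the two helper names are hypothetical), it should match `torch.nn.functional.kl_div` with `reduction='batchmean'`, times the same $T^2$ factor:

```python
import torch
import torch.nn.functional as F


def soft_target_loss_manual(teacher_logits, student_logits, T):
  """The soft-target term exactly as written in `train_knowledge_distillation`."""
  soft_targets = F.softmax(teacher_logits / T, dim=-1)
  soft_prob = F.log_softmax(student_logits / T, dim=-1)
  return (
    torch.sum(soft_targets * (soft_targets.log() - soft_prob))
    / soft_prob.size()[0]
    * (T**2)
  )


def soft_target_loss_kl(teacher_logits, student_logits, T):
  """Same quantity via F.kl_div: KL(teacher || student) per sample, batch-averaged, scaled by T^2."""
  return F.kl_div(
    F.log_softmax(student_logits / T, dim=-1),  # input: student log-probabilities
    F.softmax(teacher_logits / T, dim=-1),  # target: teacher probabilities
    reduction='batchmean',  # sum over classes and batch, divided by batch size
  ) * (T**2)


torch.manual_seed(0)
teacher_logits = torch.randn(128, 10)  # batch of 128, 10 CIFAR-10 classes
student_logits = torch.randn(128, 10)
manual = soft_target_loss_manual(teacher_logits, student_logits, T=2)
via_kl = soft_target_loss_kl(teacher_logits, student_logits, T=2)
print(manual.item(), via_kl.item())
```

Using `F.kl_div` also avoids calling `.log()` on probabilities that could underflow to zero, since it handles zero targets internally.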

### Cosine loss minimisation run

assumption: the teacher network has better internal [[thoughts/representations]] than the student. We therefore add a loss term that pushes the student's hidden representations to "mimic" the teacher's.

We apply `CosineEmbeddingLoss` between the two flattened hidden representations, so that the student's representation is pushed to point in the same direction as the teacher's:

$$
\text{loss}(x,y) = \begin{cases}
1 - \cos(x_1, x_2), & \text{if } y = 1 \\
\max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
\end{cases}
$$
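A quick check of the formula above (a sketch on random vectors; shapes and values are arbitrary): computing the per-pair loss by hand with `F.cosine_similarity` and averaging over the batch should match `nn.CosineEmbeddingLoss`, whose default `reduction='mean'` averages the per-pair terms. In the training loop the target is all ones, so only the $1 - \cos(x_1, x_2)$ branch is active.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x1 = torch.randn(8, 1024)  # stand-in for student hidden representations
x2 = torch.randn(8, 1024)  # stand-in for teacher hidden representations
y = torch.tensor([1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0])  # per-pair targets
margin = 0.0

cos = F.cosine_similarity(x1, x2, dim=1)  # shape (8,)
# y = 1: 1 - cos; y = -1: max(0, cos - margin)
per_pair = torch.where(y == 1, 1 - cos, torch.clamp(cos - margin, min=0))
manual = per_pair.mean()

builtin = nn.CosineEmbeddingLoss(margin=margin)(x1, x2, y)
print(manual.item(), builtin.item())
```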

The updated training loop is as follows [^internal]:

[^internal]: Naturally, both networks now also have to return their hidden representation next to the logits. A quick shape check:

    ```python
    sample_input = torch.randn(128, 3, 32, 32).to(
      device
    )  # Batch size: 128, Channels: 3, Image size: 32x32
    logits, hidden_representation = modified_nn_light(sample_input)

    print('Student logits shape:', logits.shape)  # batch_size x total_classes
    print(
      'Student hidden representation shape:', hidden_representation.shape
    )  # batch_size x hidden_representation_size

    logits, hidden_representation = modified_nn_deep(sample_input)

    print('Teacher logits shape:', logits.shape)  # batch_size x total_classes
    print(
      'Teacher hidden representation shape:', hidden_representation.shape
    )  # batch_size x hidden_representation_size
    ```

```python title="modified_deep_cosine.py" path="thoughts/scripts/modified_deep_cosine.py"
import torch
import torch.nn as nn
import torch.optim as optim


class ModifiedDeepNNCosine(nn.Module):
  def __init__(self, num_classes=10):
    super(ModifiedDeepNNCosine, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 128, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(128, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(64, 32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
      nn.Linear(2048, 512),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(512, num_classes),
    )

  def forward(self, x):
    x = self.features(x)
    flattened_conv_output = torch.flatten(x, 1)
    x = self.classifier(flattened_conv_output)
    flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(
      flattened_conv_output, 2
    )
    return x, flattened_conv_output_after_pooling


# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
  def __init__(self, num_classes=10):
    super(ModifiedLightNNCosine, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(16, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
      nn.Linear(1024, 256),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(256, num_classes),
    )

  def forward(self, x):
    x = self.features(x)
    flattened_conv_output = torch.flatten(x, 1)
    x = self.classifier(flattened_conv_output)
    return x, flattened_conv_output


def train_cosine_loss(
  teacher,
  student,
  train_loader,
  epochs,
  learning_rate,
  hidden_rep_loss_weight,
  ce_loss_weight,
  device,
):
  ce_loss = nn.CrossEntropyLoss()
  cosine_loss = nn.CosineEmbeddingLoss()
  optimizer = optim.Adam(student.parameters(), lr=learning_rate)

  teacher.to(device)
  student.to(device)
  teacher.eval()  # Teacher set to evaluation mode
  student.train()  # Student to train mode

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      # Forward pass with the teacher model and keep only the hidden representation
      with torch.no_grad():
        _, teacher_hidden_representation = teacher(inputs)

      # Forward pass with the student model
      student_logits, student_hidden_representation = student(inputs)

      # Cosine loss with an all-ones target: by the formula above, y = 1 gives 1 - cos(x1, x2), so minimising the loss increases cosine similarity.
      hidden_rep_loss = cosine_loss(
        student_hidden_representation,
        teacher_hidden_representation,
        target=torch.ones(inputs.size(0)).to(device),
      )

      # Calculate the true label loss
      label_loss = ce_loss(student_logits, labels)

      # Weighted sum of the two losses
      loss = (
        hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
      )

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(
      f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}'
    )


def test_multiple_outputs(model, test_loader, device):
  model.to(device)
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      outputs, _ = model(inputs)  # Disregard the second tensor of the tuple
      _, predicted = torch.max(outputs.data, 1)

      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f'Test Accuracy: {accuracy:.2f}%')
  return accuracy


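if __name__ == '__main__':
  # Assumes distill.py sits next to this file; importing it builds the device and the CIFAR-10 loaders
  from distill import DeepNN, device, test_loader, train, train_loader

  # distill.py only trains its teacher under its own __main__, so train one here
  torch.manual_seed(42)
  nn_deep = DeepNN(num_classes=10).to(device)
  train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)

  # No need to train the modified deep network from scratch: load the weights of the trained teacher
  modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
  modified_nn_deep.load_state_dict(nn_deep.state_dict())

  # Once again ensure the norm of the first layer is the same for both networks
  print(
    'Norm of 1st layer for deep_nn:',
    torch.norm(nn_deep.features[0].weight).item(),
  )
  print(
    'Norm of 1st layer for modified_deep_nn:',
    torch.norm(modified_nn_deep.features[0].weight).item(),
  )

  # Initialize a modified lightweight network with the same seed as the other lightweight instances.
  # It is trained from scratch to examine the effectiveness of cosine loss minimization.
  torch.manual_seed(42)
  modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
  print(
    'Norm of 1st layer:',
    torch.norm(modified_nn_light.features[0].weight).item(),
  )

  # Same 0.25 / 0.75 weighting as the KD run in distill.py
  train_cosine_loss(
    teacher=modified_nn_deep,
    student=modified_nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    hidden_rep_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
  )
  test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(
    modified_nn_light, test_loader, device
  )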

```
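Why the teacher pools its flattened features and the student does not: `CosineEmbeddingLoss` needs both representations to have the same size. A sketch with random stand-in feature maps (hypothetical tensors; the shapes follow from the conv stacks above, with 32 and 16 output channels and two 2×2 max-pools on 32×32 inputs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch = 4
# Stand-ins for the last conv outputs on 3x32x32 CIFAR-10 inputs:
# two 2x2 max-pools shrink 32x32 to 8x8 in both networks
teacher_maps = torch.randn(batch, 32, 8, 8)  # teacher's last conv has 32 channels
student_maps = torch.randn(batch, 16, 8, 8)  # student's last conv has 16 channels

teacher_flat = torch.flatten(teacher_maps, 1)  # (4, 2048), matches nn.Linear(2048, 512)
student_flat = torch.flatten(student_maps, 1)  # (4, 1024), matches nn.Linear(1024, 256)

# A 2D input is treated as (channels, length), so each row is pooled independently;
# kernel 2 (stride defaults to 2) averages adjacent pairs, halving 2048 -> 1024
teacher_pooled = F.avg_pool1d(teacher_flat, 2)

loss = nn.CosineEmbeddingLoss()(student_flat, teacher_pooled, torch.ones(batch))
print(teacher_flat.shape, teacher_pooled.shape, student_flat.shape, loss.item())
```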

