diff --git a/notebooks/core/vector.ipynb b/notebooks/core/vector.ipynb new file mode 100644 index 0000000..aa079c2 --- /dev/null +++ b/notebooks/core/vector.ipynb @@ -0,0 +1,502 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6d23ce3f", + "metadata": {}, + "source": [ + "# Tutorial on quantEM `Vector` class\n", + "\n", + "This tutorial demonstrates how the quantEM `Vector` module works \n", + "\n", + "Colin Ophus and Arthur McCray\n", + "March 5, 2026" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ee8a094b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "d:\\code\\quantem\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import quantem as em\n", + "from quantem.core.datastructures import Vector \n" + ] + }, + { + "cell_type": "markdown", + "id": "18674f63", + "metadata": {}, + "source": [ + "## Creating a Vector\n", + "\n", + "A `Vector` stores ragged per-cell data on a fixed grid. Each cell holds a variable number of rows, one value per named field.\n", + "\n", + "Create an empty `Vector` with `from_shape`, then assign cell data with `[]`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "204d3478", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "quantem.Vector, shape=(20, 30), name=diffraction_vectors\n", + " fields = ['kx', 'ky', 'intensity']\n", + " units: ['A^-1', 'A^-1', 'counts']\n" + ] + } + ], + "source": [ + "Nx, Ny = 20, 30 # fixed-grid dimensions (e.g. scan positions)\n", + "\n", + "v = Vector.from_shape(\n", + " shape=(Nx, Ny),\n", + " fields=(\"kx\", \"ky\", \"intensity\"),\n", + " units=(\"A^-1\", \"A^-1\", \"counts\"),\n", + " name=\"diffraction_vectors\",\n", + ")\n", + "\n", + "# Assign each cell a 2D array of shape (n_rows, num_fields)\n", + "rng = np.random.default_rng(42)\n", + "for rx in range(Nx):\n", + " for ry in range(Ny):\n", + " n = rng.integers(5, 20)\n", + " phi = rng.random(n) * 2 * np.pi\n", + " r = 10 + rng.standard_normal(n) * 2\n", + " v[rx, ry] = np.column_stack((r * np.cos(phi), r * np.sin(phi), rng.random(n)))\n", + "\n", + "print(v)" + ] + }, + { + "cell_type": "markdown", + "id": "b8dd9fa4", + "metadata": {}, + "source": [ + "Alternatively, create from existing nested data with `from_data`. Cells can be lists, tuples, or arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c47d7e96", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "quantem.Vector, shape=(2,), name=example\n", + " fields = ['x', 'y']\n", + " units: ['m', 'm']\n", + "\n", + "cell 0:\n", + " [[1. 2.]\n", + " [3. 4.]]\n", + "\n", + "cell 1:\n", + " [[ 5. 6.]\n", + " [ 7. 8.]\n", + " [ 9. 10.]]\n" + ] + } + ], + "source": [ + "v_small = Vector.from_data(\n", + " data=[\n", + " np.array([[1.0, 2.0], [3.0, 4.0]]), # cell 0: 2 rows\n", + " np.array([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]), # cell 1: 3 rows\n", + " ],\n", + " fields=[\"x\", \"y\"],\n", + " units=[\"m\", \"m\"],\n", + " name=\"example\",\n", + ")\n", + "print(v_small)\n", + "print(\"\\ncell 0:\\n\", v_small[0].array)\n", + "print(\"\\ncell 1:\\n\", v_small[1].array)" + ] + }, + { + "cell_type": "markdown", + "id": "66342311", + "metadata": {}, + "source": [ + "## Properties" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cc24d33c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape: (20, 30)\n", + "fields: ['kx', 'ky', 'intensity']\n", + "units: ['A^-1', 'A^-1', 'counts']\n", + "num_fields: 3\n", + "num_cells: 600\n", + "total_rows: 7243\n", + "dtype: float64\n", + "name: diffraction_vectors\n" + ] + } + ], + "source": [ + "print(\"shape: \", v.shape)\n", + "print(\"fields: \", v.fields)\n", + "print(\"units: \", v.units)\n", + "print(\"num_fields: \", v.num_fields)\n", + "print(\"num_cells: \", v.num_cells)\n", + "print(\"total_rows: \", v.total_rows)\n", + "print(\"dtype: \", v.dtype)\n", + "print(\"name: \", v.name)" + ] + }, + { + "cell_type": "markdown", + "id": "8d545fee", + "metadata": {}, + "source": [ + "## Fixed-grid indexing\n", + "\n", + "`[]` selects along the fixed-grid axes — the same as NumPy indexing on a regular array. It always returns a `Vector` view over shared storage.\n", + "\n", + "- Integer index → 0D `Vector`; access the cell array with `.array`\n", + "- Slice/fancy index → sub-grid `Vector`\n", + "- `.flatten()` concatenates all selected cells row-wise into a 2D NumPy array" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "acfa59aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cell shape: ()\n", + "cell array:\n", + " [[ -8.68517809 3.50970304 0.82276161]\n", + " [ 6.28492484 -7.73490805 0.4434142 ]\n", + " [ -2.69304819 -7.84451847 0.22723872]\n", + " [ 9.75950493 6.55906592 0.55458479]\n", + " [ 11.42029689 -1.76304781 0.06381726]\n", + " [ 0.70859257 -10.10725307 0.82763117]]\n", + "\n", + "row shape: (30,)\n", + "flattened shape: (384, 3)\n" + ] + } + ], + "source": [ + "# 0D selection → use .array to get the NumPy array for that cell\n", + "cell = v[0, 0]\n", + "print(\"cell shape:\", cell.shape) # () means scalar / 0D\n", + "print(\"cell array:\\n\", cell.array)\n", + "\n", + "# Slice along one or both axes → returns a sub-grid Vector\n", + "row = v[0, :] # first row, all columns\n", + "print(\"\\nrow shape:\", row.shape)\n", + "\n", + "# Flatten all selected cells into one 2D array\n", + "flat = v[0, :].flatten()\n", + "print(\"flattened shape:\", flat.shape) # (total_rows, num_fields)\n", + "\n", + "# Copy data from one region to another (write-through)\n", + "v[1, :] = v[0, :]" + ] + }, + { + "cell_type": "markdown", + "id": "603933ad", + "metadata": {}, + "source": [ + "## Field selection & arithmetic\n", + "\n", + "`select_fields(...)` returns a **write-through view** over a subset of fields. Changes made through the view are reflected in the parent `Vector`.\n", + "\n", + "Arithmetic operators (`+`, `-`, `*`, `/`, `**`, `%`, `//`, unary `-`, `abs`) all work on `Vector` objects and return new `Vector` instances. In-place operators (`+=`, `*=`, etc.) modify the backing data directly." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "34d4dd53", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "quantem.Vector, shape=(20, 30), name=r_squared\n", + " fields = ['r_squared']\n", + " units: ['A^-1']\n", + "r² first 5 values: [ 434.1351322 564.92961994 243.58684533 1172.46354937 954.56348915]\n", + "\n", + "v fields after modifications: ['kx', 'ky', 'intensity']\n", + "kx range: [0.40, 31.14]\n" + ] + } + ], + "source": [ + "v2 = v.copy() \n", + "kx = v2.select_fields(\"kx\")\n", + "ky = v2.select_fields(\"ky\")\n", + "intensity = v2.select_fields(\"intensity\")\n", + "\n", + "# In-place: modifies v directly\n", + "kx += 16\n", + "ky += 16\n", + "\n", + "# Arithmetic between field views\n", + "r_squared = kx**2 + ky**2 # new Vector, field value is \"kx\" same as kx vector\n", + "r_squared.rename_fields({\"kx\": \"r_squared\"})\n", + "r_squared.name = \"r_squared\"\n", + "print(r_squared)\n", + "print(\"r² first 5 values:\", r_squared.flatten()[:5, 0])\n", + "\n", + "# Scale intensity in-place by a per-row factor\n", + "scale = kx.flatten() / r_squared.flatten()\n", + "intensity.set_flattened(intensity.flatten() * scale)\n", + "\n", + "# Assign using [...]\n", + "kx[...] = np.abs(kx.flatten()) # equivalent to abs(kx)[...] = ...\n", + "\n", + "print(\"\\nv fields after modifications:\", v2.fields)\n", + "print(\"kx range: [{:.2f}, {:.2f}]\".format(kx.flatten().min(), kx.flatten().max()))" + ] + }, + { + "cell_type": "markdown", + "id": "c377c018", + "metadata": {}, + "source": [ + "## NumPy ufunc support\n", + "\n", + "NumPy elementwise ufuncs that return a numpy array work directly on `Vector` objects via `__array_ufunc__`. Multi-output ufuncs (e.g. `np.modf`) return tuples of `Vector`. Reduction-based functions (such as `np.mean(arr)` or `np.sum(arr)`) do not work. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "69aec78e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kx[:2]:\n", + " [[-8.68517809]\n", + " [ 6.28492484]]\n", + "\n", + "sin(kx)[:2]:\n", + " [[-0.67399238]\n", + " [ 0.00173953]]\n", + "\n", + "maximum(kx, 10)[:2]:\n", + " [[10.]\n", + " [10.]]\n", + "\n", + "modf frac[:2]: [-0.68517809 0.28492484]\n", + "modf whole[:2]: [-8. 6.]\n" + ] + } + ], + "source": [ + "kx = v.select_fields(\"kx\")\n", + "ky = v.select_fields(\"ky\")\n", + "intensity = v.select_fields(\"intensity\")\n", + "print(\"kx[:2]:\\n\", kx.flatten()[:2])\n", + "\n", + "print(\"\\nsin(kx)[:2]:\\n\", np.sin(kx).flatten()[:2])\n", + "print(\"\\nmaximum(kx, 10)[:2]:\\n\", np.maximum(kx, 10).flatten()[:2]) # type: ignore[call-overload]\n", + "\n", + "# Multi-output: modf returns (fractional, integer) as a tuple of Vectors\n", + "frac, whole = np.modf(kx)\n", + "print(\"\\nmodf frac[:2]:\", frac.flatten()[:2, 0])\n", + "print(\"modf whole[:2]:\", whole.flatten()[:2, 0])" + ] + }, + { + "cell_type": "markdown", + "id": "7e8a30a1", + "metadata": {}, + "source": [ + "## Row-wise updates with `set_flattened`\n", + "\n", + "`set_flattened(values)` writes back into all selected cells without changing per-cell row counts. This is the natural companion to `flatten()` for applying row-wise NumPy transforms." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "17950b91", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rows zeroed: 3599 / 7229\n" + ] + } + ], + "source": [ + "v2 = v.copy()\n", + "kx = v2.select_fields(\"kx\")\n", + "ky = v2.select_fields(\"ky\")\n", + "\n", + "# Zero-out kx for any row within radius 2 of the origin\n", + "mask = (kx.flatten()**2 + ky.flatten()**2) < 100 # shape (total_rows, 1)\n", + "kx.set_flattened(np.where(mask, 0.0, kx.flatten()))\n", + "\n", + "print(f\"rows zeroed: {mask.sum()} / {v2.total_rows}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7359f3c4", + "metadata": {}, + "source": [ + "## Schema operations\n", + "\n", + "`add_fields` and `remove_fields` modify the field schema for the whole `Vector`. `append_rows` adds rows to a single cell." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6f0bd9ce", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fields after add: ['kx', 'ky', 'intensity', 'r']\n", + "fields after remove: ['kx', 'ky', 'intensity']\n", + "\n", + "cell [0,0] rows: 6 → 7\n" + ] + } + ], + "source": [ + "kx = v.select_fields(\"kx\")\n", + "ky = v.select_fields(\"ky\")\n", + "\n", + "# Add a derived field with initial values\n", + "v.add_fields(\"r\", values=np.sqrt(kx**2 + ky**2), units=\"A^-1\")\n", + "print(\"fields after add:\", v.fields)\n", + "\n", + "# Remove it again\n", + "v.remove_fields(\"r\")\n", + "print(\"fields after remove:\", v.fields)\n", + "\n", + "# Append rows to a single cell\n", + "before = v[0, 0].array.shape[0]\n", + "v.append_rows((0, 0), np.array([[1.0, 2.0, 0.5]]))\n", + "after = v[0, 0].array.shape[0]\n", + "print(f\"\\ncell [0,0] rows: {before} → {after}\")" + ] + }, + { + "cell_type": "markdown", + "id": "83272acc", + "metadata": {}, + "source": [ + "## File I/O\n", + "\n", + "`Vector` can be saved and loaded using the standard quantEM I/O interface." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b9eb3bdf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loaded: quantem.Vector, shape=(20, 30), name=diffraction_vectors\n", + " fields = ['kx', 'ky', 'intensity']\n", + " units: ['A^-1', 'A^-1', 'counts']\n", + "Round-trip OK\n" + ] + } + ], + "source": [ + "import tempfile\n", + "\n", + "with tempfile.NamedTemporaryFile(suffix=\".zip\", delete=False) as tmp:\n", + " path = tmp.name\n", + "\n", + "try:\n", + " v.save(path, mode=\"o\")\n", + " v_loaded = em.io.load(path)\n", + "\n", + " # Verify round-trip\n", + " print(\"loaded:\", v_loaded)\n", + " np.testing.assert_array_equal(v[3, 5].array, v_loaded[3, 5].array)\n", + " print(\"Round-trip OK\")\n", + "finally:\n", + " import os\n", + " if os.path.exists(path):\n", + " os.remove(path)" + ] + }, + { + "cell_type": "markdown", + "id": "74c8048a", + "metadata": {}, + "source": [ + "-- end notebook --" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv (3.12.10)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}