diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 4f0ca12..6d393a2 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,6 +1,6 @@ -source-sha: 450bafecd23db638602150b47f4272b98aad3146 -synced-at: "2026-04-14" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-05-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 -tool-version: 0.14.1 +tool-version: 0.15.0 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index e904fbe..66ae445 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,6 +1,6 @@ -source-sha: 450bafecd23db638602150b47f4272b98aad3146 -synced-at: "2026-04-14" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-05-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 -tool-version: 0.14.1 +tool-version: 0.15.0 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index a00de1b..6301535 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -24,7 +24,7 @@ translation: JAX as a NumPy Replacement::Differences::A Workaround: راه‌حل جایگزین Functional Programming: برنامه‌نویسی تابعی Functional Programming::Pure functions: توابع خالص - Functional Programming::Examples: مثال‌ها + Functional Programming::Examples -- Pure and Impure: مثال‌ها -- خالص و ناخالص Functional Programming::Why Functional Programming?: چرا برنامه‌نویسی تابعی؟ Random numbers: اعداد تصادفی Random numbers::NumPy / MATLAB Approach: رویکرد NumPy / MATLAB @@ -356,19 +356,20 @@ a * وضعیت سراسری را تغییر نمی‌دهد * داده‌های ارسال شده به تابع را تغییر نمی‌دهد (داده‌های تغییرناپذیر) -### مثال‌ها +### مثال‌ها -- خالص و ناخالص -در اینجا مثالی از یک تابع *غیرخالص* آورده شده است +در اینجا مثالی از یک تابع *ناخالص* آورده شده است ```{code-cell} ipython3 tax_rate = 0.1 -prices = [10.0, 20.0] def add_tax(prices): for i, price in enumerate(prices): prices[i] = price * (1 + tax_rate) - print('Post-tax prices: ', prices) - return prices + +prices = [10.0, 20.0] +add_tax(prices) +prices ``` این تابع نمی‌تواند خالص باشد زیرا @@ -379,15 +380,21 @@ def add_tax(prices): در اینجا یک نسخه *خالص* آورده شده است ```{code-cell} ipython3 -tax_rate = 0.1 -prices = (10.0, 20.0) def add_tax_pure(prices, tax_rate): new_prices = [price * (1 + tax_rate) for price in prices] return new_prices + +tax_rate = 0.1 +prices = (10.0, 20.0) +after_tax_prices = add_tax_pure(prices, tax_rate) +after_tax_prices ``` -این نسخه خالص تمام وابستگی‌ها را از طریق آرگومان‌های تابع صریح می‌کند و هیچ وضعیت خارجی را تغییر نمی‌دهد. +این نسخه خالص است زیرا + +* تمام وابستگی‌ها از طریق آرگومان‌های تابع صریح هستند +* و هیچ وضعیت خارجی را تغییر نمی‌دهد ### چرا برنامه‌نویسی تابعی؟ @@ -437,7 +444,7 @@ print(np.random.randn(2)) * غیرقطعی است: ورودی‌های یکسان، خروجی‌های متفاوت * دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد -در موازی‌سازی خطرناک است --- باید با دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد! +این در موازی‌سازی خطرناک است --- باید با دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد. ### JAX @@ -554,7 +561,11 @@ plt.show() تابع زیر `k` ماتریس تصادفی `n x n` (شبه) مستقل را با استفاده از `split` تولید می‌کند. ```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): +def gen_random_matrices( + key, # JAX key for random numbers + n=2, # Matrices will be n x n + k=3 # Number of matrices to generate + ): matrices = [] for _ in range(k): key, subkey = jax.random.split(key) @@ -576,7 +587,7 @@ gen_random_matrices(key) ### مزایا -صریح بودن JAX مزایای قابل توجهی به همراه دارد: +همان‌طور که در بالا ذکر شد، این صراحت ارزشمند است: * تکرارپذیری: با استفاده مجدد از کلیدها، تکرار نتایج آسان است * موازی‌سازی: کنترل آنچه در رشته‌های جداگانه اتفاق می‌افتد @@ -657,7 +668,14 @@ with qe.Timer(): نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT. -اما همچنان از اجرای eager استفاده می‌کنیم --- حافظه و خواندن/نوشتن زیاد. +این به این دلیل است که عملیات‌های آرایه‌ای منفرد روی GPU موازی‌سازی می‌شوند. + +اما همچنان از اجرای eager استفاده می‌کنیم + +* حافظه زیاد به دلیل آرایه‌های میانی +* خواندن/نوشتن حافظه زیاد + +همچنین، هسته‌های جداگانه زیادی روی GPU راه‌اندازی می‌شوند. ### کامپایل کل تابع @@ -691,7 +709,8 @@ with qe.Timer(): * بهینه‌سازی تهاجمی بر اساس کل دنباله محاسباتی * حذف چندین فراخوانی به شتاب‌دهنده سخت‌افزاری -* عدم ایجاد آرایه‌های میانی + +ردپای حافظه نیز بسیار کمتر است --- عدم ایجاد آرایه‌های میانی. اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index d560e90..1c5f849 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -13,6 +13,7 @@ translation: Vectorized operations: عملیات برداری شده Vectorized operations::Problem Statement: بیان مسئله Vectorized operations::NumPy vectorization: برداری‌سازی NumPy + Vectorized operations::Memory Issues: مشکلات حافظه Vectorized operations::A Comparison with Numba: مقایسه با Numba Vectorized operations::Parallelized Numba: Numba موازی شده Vectorized operations::Vectorized code with JAX: کد برداری شده با JAX @@ -146,16 +147,33 @@ for x in grid: بیایید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) # Large grid +``` + +به عنوان اولین گام برداری‌سازی ممکن است چیزی شبیه به این امتحان کنیم + +```{code-cell} ipython3 +# Large grid +z = np.max(f(grid, grid)) # This is wrong! +``` + +مشکل اینجا این است که `f(grid, grid)` از حلقه تودرتو پیروی نمی‌کند. + +از نظر شکل بالا، این کد فقط مقادیر `f` را روی قطر محاسبه می‌کند. + +برای اینکه NumPy را مجبور کنیم `f(x,y)` را روی هر جفت `x,y` محاسبه کند، باید از `np.meshgrid` استفاده کنیم. + در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند. ```{code-cell} ipython3 # Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid +x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): - z_max_numpy = np.max(f(x, y)) + z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works ``` در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند. @@ -168,9 +186,29 @@ with qe.Timer(): print(f"NumPy result: {z_max_numpy:.6f}") ``` +### مشکلات حافظه + +پس ما راه‌حل صحیح را در زمان معقول داریم --- اما مصرف حافظه بسیار زیاد است. + +در حالی که آرایه‌های تخت حافظه کمی دارند + +```{code-cell} ipython3 +grid.nbytes +``` + +شبکه‌های mesh دوبعدی هستند و از این رو از نظر حافظه بسیار فشرده‌اند + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` + +علاوه بر این، اجرای بلادرنگ NumPy آرایه‌های میانی زیادی با همان اندازه ایجاد می‌کند! + +این نوع مصرف حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد. + ### مقایسه با Numba -حالا بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم. +بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم. ```{code-cell} ipython3 @numba.jit @@ -201,13 +239,13 @@ with qe.Timer(): compute_max_numba(grid) ``` -بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد. +توجه کنید که تقریباً هیچ حافظه‌ای استفاده نمی‌کنیم --- فقط به `grid` یک‌بعدی نیاز داریم. -در اکثر موارد، Numba کمی بهتر است. +علاوه بر این، سرعت اجرا خوب است. -از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب می‌کند که مزیتی فراهم می‌کند. +در اکثر دستگاه‌ها، نسخه Numba تا حدودی سریعتر از NumPy خواهد بود. -از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم. +دلیل آن کد ماشین کارآمد به علاوه خواندن و نوشتن کمتر حافظه است. ### Numba موازی شده @@ -301,25 +339,11 @@ with qe.Timer(): ### JAX به علاوه vmap -یک مشکل با کد NumPy و کد JAX وجود دارد: - -در حالی که آرایه‌های تخت حافظه کمی دارند - -```{code-cell} ipython3 -grid.nbytes -``` - -شبکه‌های mesh فشرده از نظر حافظه هستند +چون از `jax.jit` در بالا استفاده کردیم، از ایجاد بسیاری از آرایه‌های میانی اجتناب کردیم. -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` +اما همچنان آرایه‌های بزرگ `z_max`، `x_mesh` و `y_mesh` را ایجاد می‌کنیم. -این استفاده اضافی از حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد. - -خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) می‌پذیرد. - -ایده `vmap` این است که برداری‌سازی را به مراحل تقسیم کند و تابعی که روی مقادیر تکی عمل می‌کند را به تابعی تبدیل کند که روی آرایه‌ها عمل می‌کند. +خوشبختانه، می‌توانیم با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) از این اجتناب کنیم. در اینجا نحوه اعمال آن به مسئله ما آمده است. @@ -327,13 +351,13 @@ x_mesh.nbytes + y_mesh.nbytes @jax.jit def compute_max_vmap(grid): # Construct a function that takes the max over all x for given y - f_vec_x_max = lambda y: jnp.max(f(grid, y)) + compute_column_max = lambda y: jnp.max(f(grid, y)) # Vectorize the function so we can call on all y simultaneously - f_vec_max = jax.vmap(f_vec_x_max) - # Compute the max across x at every y - maxes = f_vec_max(grid) - # Compute the max of the maxes and return - return jnp.max(maxes) + vectorized_compute_column_max = jax.vmap(compute_column_max) + # Compute the column max at every row + column_maxes = vectorized_compute_column_max(grid) + # Compute the max of the column maxes and return + return jnp.max(column_maxes) ``` توجه کنید که هرگز @@ -344,6 +368,8 @@ def compute_max_vmap(grid): را نمی‌سازیم. +مانند Numba، فقط از آرایه تخت `grid` استفاده می‌کنیم. + و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند. بیایید آن را امتحان کنیم. @@ -374,13 +400,11 @@ with qe.Timer(): هم از نظر سرعت (از طریق JIT-compilation و موازی‌سازی) و هم از نظر کارایی حافظه (از طریق vmap) بر NumPy غلبه می‌کند. -علاوه بر این، رویکرد `vmap` گاهی اوقات می‌تواند منجر به کد به طور قابل توجهی واضح‌تری شود. +همچنین هنگام اجرا روی GPU بر Numba نیز غلبه می‌کند. -در حالی که Numba چشمگیر است، زیبایی JAX این است که با عملیات کاملاً برداری شده، می‌توانیم دقیقاً همان کد را روی دستگاه‌های با شتاب‌دهنده سخت‌افزاری اجرا کنیم و بدون تلاش اضافی از تمام مزایا بهره‌مند شویم. - -علاوه بر این، JAX قبلاً می‌داند چگونه بسیاری از عملیات آرایه رایج را به طور مؤثر موازی کند، که کلید اجرای سریع است. - -برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان کدنویسی دستی کنیم. +```{note} +Numba می‌تواند برنامه‌نویسی GPU را از طریق `numba.cuda` پشتیبانی کند، اما در آن صورت باید موازی‌سازی را به صورت دستی انجام دهیم. برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان به صورت دستی کدنویسی کنیم. +``` ## عملیات ترتیبی @@ -530,8 +554,6 @@ with qe.Timer(): در حالی که سینتکس `at[t].set` در JAX به‌روزرسانی عنصر به عنصر را ممکن می‌سازد، کد کلی همچنان سخت‌تر از معادل Numba برای خواندن است. -برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است. - ## توصیه‌های کلی حال قدمی به عقب بر می‌داریم و مبادلات را خلاصه می‌کنیم. @@ -544,17 +566,12 @@ with qe.Timer(): علاوه بر این، توابع JAX به‌صورت خودکار مشتق‌پذیر هستند، همان‌طور که در {doc}`autodiff` بررسی می‌کنیم. -برای **عملیات ترتیبی**، Numba مزایای آشکاری دارد. +برای **عملیات ترتیبی**، Numba نحو بهتری دارد. کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است. JAX می‌تواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است. -```{note} -یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کنند، که Numba قادر به انجام آن نیست. -اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است. -``` - -در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند. +از سوی دیگر، نسخه‌های JAX از مشتق‌گیری خودکار پشتیبانی می‌کنند. -یک قاعده سرانگشتی مناسب: برای پروژه‌های جدید، به‌ویژه زمانی که شتاب‌دهی سخت‌افزاری یا مشتق‌پذیری ممکن است مفید باشد، به‌طور پیش‌فرض از JAX استفاده کنید، و هنگامی که یک حلقه ترتیبی فشرده نیاز به سرعت و خوانایی دارد، به Numba متوسل شوید. +این ممکن است جالب توجه باشد اگر، برای مثال، بخواهیم حساسیت‌های یک مسیر را نسبت به پارامترهای مدل محاسبه کنیم.