diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh
index 31f42fdbd8..c646b8f9a8 100755
--- a/.ci/docker/build.sh
+++ b/.ci/docker/build.sh
@@ -11,8 +11,9 @@ IMAGE_NAME="$1"
 shift
 
 export UBUNTU_VERSION="20.04"
+export CUDA_VERSION="12.4.1"
 
-export BASE_IMAGE="ubuntu:${UBUNTU_VERSION}"
+export BASE_IMAGE="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}"
 echo "Building ${IMAGE_NAME} Docker image"
 
 docker build \
diff --git a/.ci/docker/common/common_utils.sh b/.ci/docker/common/common_utils.sh
index b20286a409..c7eabda555 100644
--- a/.ci/docker/common/common_utils.sh
+++ b/.ci/docker/common/common_utils.sh
@@ -22,5 +22,5 @@ conda_run() {
 }
 
 pip_install() {
-  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip install --progress-bar off $*
+  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip3 install --progress-bar off $*
 }
diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 00cf2f2103..9668b17fc3 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -30,8 +30,8 @@ pytorch-lightning
 torchx
 torchrl==0.5.0
 tensordict==0.5.0
-ax-platform>==0.4.0
-nbformat>==5.9.2
+ax-platform>=0.4.0
+nbformat>=5.9.2
 datasets
 transformers
 torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
@@ -68,4 +68,4 @@ pygame==2.1.2
 pycocotools
 semilearn==0.3.2
 torchao==0.0.3
-segment_anything==1.0
\ No newline at end of file
+segment_anything==1.0
diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json
index 4814f9a7d2..2f1a9933aa 100644
--- a/.jenkins/metadata.json
+++ b/.jenkins/metadata.json
@@ -28,6 +28,9 @@
   "intermediate_source/model_parallel_tutorial.py": {
     "needs": "linux.16xlarge.nvidia.gpu"
   },
+  "recipes_source/torch_export_aoti_python.py": {
+    "needs": "linux.g5.4xlarge.nvidia.gpu"
+  },
   "advanced_source/pendulum.py": {
     "needs": "linux.g5.4xlarge.nvidia.gpu",
     "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
diff --git a/en-wordlist.txt b/en-wordlist.txt
index 62762ab69c..e69cbaa1a5 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -2,6 +2,7 @@
 ACL
 ADI
 AOT
+AOTInductor
 APIs
 ATen
 AVX
@@ -617,4 +618,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
\ No newline at end of file
+wsis
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst
index d94d7d5c22..caccdcc28f 100644
--- a/recipes_source/recipes_index.rst
+++ b/recipes_source/recipes_index.rst
@@ -150,6 +150,12 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/swap_tensors.html
    :tags: Basics
 
+.. customcarditem::
+   :header: torch.export AOTInductor Tutorial for Python runtime
+   :card_description: Learn an end-to-end example of how to use AOTInductor for Python runtime.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_export_aoti_python.html
+   :tags: Basics
 
 .. Interpretability
 
diff --git a/recipes_source/torch_export_aoti_python.py b/recipes_source/torch_export_aoti_python.py
new file mode 100644
index 0000000000..136862078c
--- /dev/null
+++ b/recipes_source/torch_export_aoti_python.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+
+"""
+(Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
+===============================================================
+**Authors:** Ankith Gunapal, Bin Bao, Angela Yi
+"""
+
+######################################################################
+#
+# .. warning::
+#
+#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in beta status and are subject to backward-compatibility
+#     breaking changes. This tutorial provides an example of how to use these APIs for model deployment using the Python runtime.
+#
+# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
+# to perform Ahead-of-Time compilation of PyTorch exported models by creating
+# a shared library that can be run in a non-Python environment.
+#
+#
+# In this tutorial, you will work through an end-to-end example of using AOTInductor for the Python runtime.
+# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
+# shared library. Additionally, we will examine how to execute the shared library in the Python runtime using :func:`torch._export.aot_load`.
+# You will also learn about the speedup in first inference time that AOTInductor provides, especially when using
+# ``max-autotune`` mode, which can take a significant amount of time to compile.
+#
+# **Contents**
+#
+# .. contents::
+#     :local:
+
+######################################################################
+# Prerequisites
+# -------------
+# * PyTorch 2.4 or later
+# * Basic understanding of ``torch.export`` and AOTInductor
+# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial
+
+######################################################################
+# What you will learn
+# ----------------------
+# * How to use AOTInductor for the Python runtime
+# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library
+# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`
+# * When to use AOTInductor for the Python runtime
+
+######################################################################
+# Model Compilation
+# -----------------
+#
+# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor on the
+# exported PyTorch program using :func:`torch._inductor.aot_compile`.
+#
+# .. note::
+#
+#       This API also supports :func:`torch.compile` options like ``mode``.
+#       This means that if used on a CUDA-enabled device, you can, for example, set ``"max_autotune": True``,
+#       which leverages Triton-based matrix multiplications and convolutions, and enables CUDA graphs by default.
+#
+# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
+# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.
+
+
+import os
+import torch
+from torchvision.models import ResNet18_Weights, resnet18
+
+model = resnet18(weights=ResNet18_Weights.DEFAULT)
+model.eval()
+
+with torch.inference_mode():
+
+    # Specify the generated shared library path
+    aot_compile_options = {
+            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
+    }
+    if torch.cuda.is_available():
+        device = "cuda"
+        aot_compile_options.update({"max_autotune": True})
+    else:
+        device = "cpu"
+
+    model = model.to(device=device)
+    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
+
+    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
+    batch_dim = torch.export.Dim("batch", min=2, max=32)
+    exported_program = torch.export.export(
+        model,
+        example_inputs,
+        # Specify the first dimension of the input x as dynamic
+        dynamic_shapes={"x": {0: batch_dim}},
+    )
+    so_path = torch._inductor.aot_compile(
+        exported_program.module(),
+        example_inputs,
+        # Specify the generated shared library path
+        options=aot_compile_options
+    )
+
+
+######################################################################
+# Model Inference in Python
+# -------------------------
+#
+# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
+# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
+# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
+# of the shared library and the device where it should be loaded.
+#
+# .. note::
+#      In the example below, we specify ``batch_size=1`` for inference and it still functions correctly even though we specified ``min=2`` in
+#      :func:`torch.export.export`.
+
+
+import os
+import torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")
+
+model = torch._export.aot_load(model_so_path, device)
+example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
+
+with torch.inference_mode():
+    output = model(*example_inputs)
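+
+######################################################################
+# As a quick, optional sanity check: because the batch dimension was exported
+# as dynamic (``min=2``, ``max=32``), the same compiled artifact also accepts
+# other batch sizes in that range. The snippet below runs a batch of 8 through
+# the loaded model; for ResNet18 the output should have shape ``torch.Size([8, 1000])``.
+
+with torch.inference_mode():
+    # Any batch size between min=2 and max=32 works with the same shared library
+    batched_inputs = torch.randn(8, 3, 224, 224, device=device)
+    batched_output = model(batched_inputs)
+    print(batched_output.shape)  # expected: torch.Size([8, 1000])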
+
+######################################################################
+# When to use AOTInductor for Python Runtime
+# ------------------------------------------
+#
+# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
+# Once this requirement is met, the primary use case for using AOTInductor Python Runtime is for
+# model deployment using Python.
+# There are mainly two reasons why you would use AOTInductor for the Python runtime:
+#
+# -  ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
+#    versioning in deployments and for tracking model performance over time.
+# -  With :func:`torch.compile` being a JIT compiler, there is a warmup
+#    cost associated with the first compilation. Your deployment needs to account for the
+#    compilation time taken for the first inference. With AOTInductor, the compilation is
+#    done offline using ``torch.export.export`` and ``torch._inductor.aot_compile``. At deployment
+#    time, you only load the shared library using ``torch._export.aot_load`` and run inference,
+#    as sketched below.
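+#
+# A minimal sketch of this offline/online split, reusing the ``exported_program``,
+# ``example_inputs``, and ``device`` variables defined earlier in this recipe:
+#
+# .. code-block:: python
+#
+#     # Offline (build) phase: compile the exported program into a shared library
+#     so_path = torch._inductor.aot_compile(exported_program.module(), example_inputs)
+#
+#     # Online (deployment) phase: load the shared library and run inference
+#     loaded = torch._export.aot_load(so_path, device)
+#     output = loaded(*example_inputs)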
+#
+#
+# The section below shows the speedup achieved with AOTInductor for the first inference.
+#
+# We define a utility function ``timed`` to measure the time taken for inference.
+#
+
+import time
+def timed(fn):
+    # Returns the result of running `fn()` and the time it took for `fn()` to run,
+    # in milliseconds. We use CUDA events and synchronization for accurate
+    # measurement on CUDA-enabled devices.
+    if torch.cuda.is_available():
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+    else:
+        start = time.time()
+
+    result = fn()
+    if torch.cuda.is_available():
+        end.record()
+        torch.cuda.synchronize()
+    else:
+        end = time.time()
+
+    # Measure the time taken to execute the function in milliseconds
+    if torch.cuda.is_available():
+        duration = start.elapsed_time(end)
+    else:
+        duration = (end - start) * 1000
+
+    return result, duration
+
+
+######################################################################
+# Let's measure the time taken for the first inference using AOTInductor.
+
+torch._dynamo.reset()
+
+model = torch._export.aot_load(model_so_path, device)
+example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
+
+with torch.inference_mode():
+    _, time_taken = timed(lambda: model(*example_inputs))
+    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
+
+
+######################################################################
+# Let's measure the time taken for the first inference using ``torch.compile``.
+
+torch._dynamo.reset()
+
+model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
+model.eval()
+
+model = torch.compile(model)
+example_inputs = torch.randn(1, 3, 224, 224, device=device)
+
+with torch.inference_mode():
+    _, time_taken = timed(lambda: model(example_inputs))
+    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
+
+######################################################################
+# We see that there is a drastic speedup in the first inference time using AOTInductor compared
+# to ``torch.compile``.
+
+######################################################################
+# Conclusion
+# ----------
+#
+# In this recipe, we have learned how to effectively use AOTInductor for the Python runtime by
+# compiling and loading a pretrained ``ResNet18`` model using the ``torch._inductor.aot_compile``
+# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
+# generating a shared library and running it within a Python environment, even with dynamic shape
+# considerations and device-specific optimizations. We also looked at the advantage of using
+# AOTInductor in model deployments, with regard to the speedup in first inference time.