TypeTrees for Autodiff
What are TypeTrees?
TypeTrees are memory layout descriptors for the Enzyme autodiff backend. They tell Enzyme exactly how types are laid out in memory so it can compute derivatives efficiently.
Structure
TypeTree(Vec<Type>)
Type {
    offset: isize,   // byte offset (-1 = everywhere)
    size: usize,     // size in bytes
    kind: Kind,      // Float, Integer, Pointer, etc.
    child: TypeTree, // nested structure
}
Example: fn compute(x: &f32, data: &[f32]) -> f32
Input 0: x: &f32
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: 0, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])
Input 1: data: &[f32]
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float, // -1 = all elements
        child: TypeTree::new()
    }])
}])
Output: f32
TypeTree(vec![Type {
    offset: 0, size: 4, kind: Float,
    child: TypeTree::new()
}])
Why Needed?
- Enzyme cannot reliably deduce complex type layouts from LLVM IR alone
- Avoids slow analysis of memory access patterns
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes hold differentiable data and which hold metadata (lengths, integer fields)
What Enzyme Does With This Information:
Without TypeTrees (current state):
; Enzyme sees generic LLVM IR:
define float @distance(i8* %p1, i8* %p2) {
  ; Has to guess what these pointers point to
  ; Slow analysis of all memory operations
  ; May miss optimization opportunities
}
With TypeTrees (our goal):
// Enzyme knows:
// - %p1 points to a struct with f32 at +0, f32 at +4, i32 at +8
// - Only the f32 fields need derivatives
// - Can generate efficient derivative code directly
TypeTrees - Offset and -1 Explained
Type Structure
Type {
    offset: isize,   // WHERE this type starts
    size: usize,     // HOW BIG this type is
    kind: Kind,      // WHAT KIND of data (Float, Integer, Pointer)
    child: TypeTree, // WHAT'S INSIDE (for pointers/containers)
}
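To make this structure concrete, here is a standalone Rust mirror of Type and TypeTree. These are our own illustrative definitions, not the rustc_ast types: Kind is cut down to three variants and pointers are assumed to be 8 bytes (a 64-bit target). The helper names scalar, ptr_to and compute_trees are hypothetical; compute_trees rebuilds the trees for fn compute(x: &f32, data: &[f32]) -> f32 from the Structure example.

```rust
// Standalone mirror of the TypeTree structure (illustrative; not rustc_ast's types).

/// What kind of data an entry describes (reduced to three variants here).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Kind {
    Float,
    Integer,
    Pointer,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Type {
    pub offset: isize,   // byte offset; -1 means "every offset"
    pub size: usize,     // size of one value in bytes
    pub kind: Kind,      // what lives at this offset
    pub child: TypeTree, // layout behind a pointer (empty for scalars)
}

#[derive(Debug, Clone, PartialEq, Default)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
    pub fn new() -> Self {
        TypeTree(Vec::new())
    }
}

/// Hypothetical helper: a scalar entry with no nested structure.
pub fn scalar(offset: isize, size: usize, kind: Kind) -> Type {
    Type { offset, size, kind, child: TypeTree::new() }
}

/// Hypothetical helper: a pointer at offset 0 whose pointee is `pointee`.
/// Assumes 8-byte pointers (a 64-bit target).
pub fn ptr_to(pointee: Type) -> TypeTree {
    TypeTree(vec![Type {
        offset: 0,
        size: 8,
        kind: Kind::Pointer,
        child: TypeTree(vec![pointee]),
    }])
}

/// Hypothetical helper: argument and return trees for
/// `fn compute(x: &f32, data: &[f32]) -> f32`.
pub fn compute_trees() -> (Vec<TypeTree>, TypeTree) {
    let x = ptr_to(scalar(0, 4, Kind::Float)); // &f32: one float behind the pointer
    let data = ptr_to(scalar(-1, 4, Kind::Float)); // &[f32]: a float at every offset
    let ret = TypeTree(vec![scalar(0, 4, Kind::Float)]); // f32 returned by value
    (vec![x, data], ret)
}
```

For example, the child of `compute_trees().0[1]` is the single `-1` Float entry that stands for every slice element.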
Offset Values
Regular Offset (0, 4, 8, etc.)
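A specific byte position within a structure:
#[repr(C)] // fixes field order; the default Rust layout may reorder fields
struct Point {
    x: f32,  // offset 0, size 4
    y: f32,  // offset 4, size 4
    id: i32, // offset 8, size 4
}
TypeTree for the Point itself (for &Point, this becomes the child of a Pointer entry):
TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float, child: TypeTree::new() },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float, child: TypeTree::new() },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer, child: TypeTree::new() }, // id at byte 8
])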
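Tying this back to what Enzyme gains: once the Point tree exists, picking out the bytes that need derivatives is just a filter on the Float kind. A minimal sketch using a reduced, illustrative mirror of the TypeTree types (not the rustc ones); point_tree and differentiable_ranges are hypothetical names, and Point is declared with #[repr(C)] so its offsets are guaranteed.

```rust
// Reduced mirror of the TypeTree shape (illustrative, not the rustc types).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    Float,
    Integer,
}

#[derive(Debug, Clone, PartialEq)]
struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

#[derive(Debug, Clone, PartialEq)]
struct TypeTree(Vec<Type>);

// #[repr(C)] keeps declaration order, so x, y, id sit at offsets 0, 4, 8.
#[repr(C)]
struct Point {
    x: f32,
    y: f32,
    id: i32,
}

/// Hypothetical helper: the Point tree from the example above.
fn point_tree() -> TypeTree {
    let leaf = |offset: isize, kind: Kind| Type { offset, size: 4, kind, child: TypeTree(vec![]) };
    TypeTree(vec![leaf(0, Kind::Float), leaf(4, Kind::Float), leaf(8, Kind::Integer)])
}

/// (offset, size) of every entry that can carry a derivative: only floats do.
fn differentiable_ranges(tt: &TypeTree) -> Vec<(isize, usize)> {
    tt.0.iter()
        .filter(|t| t.kind == Kind::Float)
        .map(|t| (t.offset, t.size))
        .collect()
}
```

Here `differentiable_ranges(&point_tree())` yields `[(0, 4), (4, 4)]`: the `id` field at byte 8 is skipped.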
Offset -1 (Special: “Everywhere”)
Means “this pattern applies at every offset”, so a single entry covers ALL elements
Example 1: Array [f32; 100]
TypeTree(vec![Type {
    offset: -1,  // ALL positions
    size: 4,     // each f32 is 4 bytes
    kind: Float, // every element is a float
    child: TypeTree::new(),
}])
This replaces 100 separate Type entries at offsets 0, 4, 8, 12, …, 396.
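The equivalence can be checked by expanding the -1 entry back into explicit per-element entries. This is only an illustration: Enzyme reads -1 directly, and expand is a hypothetical helper over a reduced Entry type (a Type without its child).

```rust
// Illustration of what offset -1 abbreviates. `expand` is a hypothetical
// helper; Enzyme interprets -1 directly and never performs this expansion.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    Float,
    Integer,
}

/// A leaf entry: the fields of `Type` minus `child`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Entry {
    offset: isize,
    size: usize,
    kind: Kind,
}

/// Expand `entry` across an object of `total_bytes`.
/// Entries at a fixed offset are returned unchanged.
fn expand(entry: Entry, total_bytes: usize) -> Vec<Entry> {
    if entry.offset != -1 {
        return vec![entry];
    }
    assert!(entry.size > 0, "zero-sized elements have no distinct offsets");
    (0..total_bytes / entry.size)
        .map(|i| Entry { offset: (i * entry.size) as isize, ..entry })
        .collect()
}
```

For `[f32; 100]` (400 bytes), expanding the single `-1` entry gives 100 Float entries at offsets 0 through 396.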
Example 2: Slice &[i32]
// Pointer to slice data
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer,
        child: TypeTree::new(),
    }])
}])
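Reading a nested tree like this means matching one offset per pointer level, with -1 acting as a wildcard. A minimal sketch, assuming a path of byte offsets (one per dereference) and preferring an exact offset over -1; kind_at and this lookup rule are illustrative and not Enzyme's actual query API.

```rust
// Reduced mirror of the TypeTree shape (illustrative, not the rustc types).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    Integer,
    Pointer,
}

#[derive(Debug, Clone, PartialEq)]
struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

#[derive(Debug, Clone, PartialEq)]
struct TypeTree(Vec<Type>);

/// The `&[i32]` tree above: a pointer at 0, an i32 at every offset behind it.
fn slice_i32_tree() -> TypeTree {
    let elem = Type { offset: -1, size: 4, kind: Kind::Integer, child: TypeTree(vec![]) };
    TypeTree(vec![Type { offset: 0, size: 8, kind: Kind::Pointer, child: TypeTree(vec![elem]) }])
}

/// Hypothetical lookup: follow `path` (one byte offset per level; each step
/// after the first goes through a pointer's child tree) and return the kind.
/// An exact offset match is preferred; otherwise a -1 entry matches any offset.
fn kind_at(tree: &TypeTree, path: &[isize]) -> Option<Kind> {
    let (&first, rest) = path.split_first()?;
    let entry = tree
        .0
        .iter()
        .find(|t| t.offset == first)
        .or_else(|| tree.0.iter().find(|t| t.offset == -1))?;
    if rest.is_empty() {
        Some(entry.kind)
    } else {
        kind_at(&entry.child, rest)
    }
}
```

With this rule, `kind_at(&slice_i32_tree(), &[0, 396])` finds the pointer at 0, then lets the `-1` child entry answer `Integer` for any element offset.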
Example 3: Mixed Structure
#[repr(C)] // fixes field order so the offsets below hold
struct Container {
    header: i64,            // offset 0
    data: Box<[f32; 1000]>, // offset 8: a pointer; the pointee's elements use -1
}
TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer, child: TypeTree::new() }, // header
    Type { offset: 8, size: 8, kind: Pointer,                          // data pointer
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float, child: TypeTree::new() // ALL array elements
        }])
    },
])
Core TypeTree Generation Functions:
// compiler/rustc_middle/src/ty/mod.rs
pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
    // Creates a TypeTree from any Rust type
}
pub fn fnc_typetrees<'tcx>(
    tcx: TyCtxt<'tcx>,
    fn_ty: Ty<'tcx>,
    da: &mut Vec<DiffActivity>,
    span: Option<Span>,
) -> FncTree {
    // Handles function signatures, slice ABI adjustments, activity modifications
}
fn typetree_from_ty<'a>(...) -> TypeTree {
    // Handles Rust types: primitives, ADTs, arrays, slices, pointers
}
Enums and unions are not handled.
base.rs is where all autodiff functions get collected across the entire crate:
- Functions marked with #[autodiff]
- Their generated TypeTrees
- Activity information
- Target function names
What is FieldsShape?
FieldsShape is an enum from rustc_abi that describes how the fields of a type are laid out in memory. It is central to TypeTree generation because it tells you exactly where each field starts.
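As a rough sketch of how each shape yields the offsets a TypeTree builder needs, here is a toy stand-in for the enum (its rustc form follows). Shape, field_offsets, and the use of plain u64 byte counts instead of rustc_abi::Size are assumptions made to keep the example standalone.

```rust
// Toy stand-in for rustc_abi::FieldsShape, keeping only what offset
// computation needs. Byte sizes are plain u64 instead of rustc_abi::Size.
enum Shape {
    Primitive,                         // no fields
    Union(usize),                      // field count; every field starts at 0
    Array { stride: u64, count: u64 }, // element i starts at i * stride
    Arbitrary { offsets: Vec<u64> },   // field i starts at offsets[i] (source order)
}

/// Hypothetical helper: start offset of each field or element. A TypeTree
/// builder would pair each offset with the tree of the field found there;
/// a primitive has no fields and becomes a single leaf at offset 0 instead.
fn field_offsets(shape: &Shape) -> Vec<u64> {
    match shape {
        Shape::Primitive => Vec::new(),
        Shape::Union(n) => vec![0; *n],
        Shape::Array { stride, count } => (0..*count).map(|i| i * *stride).collect(),
        Shape::Arbitrary { offsets } => offsets.clone(),
    }
}
```

For arrays, a builder may prefer to emit a single -1 entry rather than one entry per element, as described in the Offset -1 section earlier.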
pub enum FieldsShape<FieldIdx: Idx> {
    /// Scalar primitives and `!`, which never have fields.
    Primitive,
    /// Array/vector-like layout: all elements have the same layout,
    /// repeated `count` times with a `stride` between them.
    Array { stride: Size, count: u64 },
    /// Struct/tuple-like layout: arbitrary positioning of differently-sized fields.
    /// `offsets` tells you where each field starts (in source order);
    /// `memory_index` maps each field to its position in memory order.
    Arbitrary {
        offsets: IndexVec<FieldIdx, Size>,
        memory_index: IndexVec<FieldIdx, u32>,
    },
    /// Union layout: all fields start at offset 0. The `usize` is the field count.
    Union(NonZero<usize>),
}
Todo:
In the TypeTree code, the ABI-matching logic could be abstracted out along these lines:
// Check if this type points to a slice
let points_to_slice = match ty.kind() {
    ty::RawPtr(inner_ty, _) | ty::Ref(_, inner_ty, _) => inner_ty.is_slice(),
    ty::Adt(adt_def, args_inner) if adt_def.is_box() => {
        args_inner.type_at(0).is_slice()
    }
    _ => false,
};
if points_to_slice {
    // Extract the slice element type
    let slice_elem_ty = match ty.kind() {
        ty::RawPtr(inner_ty, _) | ty::Ref(_, inner_ty, _) => {
            inner_ty.builtin_index().unwrap()
        }
        ty::Adt(_, args_inner) => args_inner.type_at(0).builtin_index().unwrap(),
        _ => unreachable!(),
    };
    // Single slice handling logic
    let child = typetree_from_ty(slice_elem_ty, tcx, 1, safety, &mut visited, span);
    let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
    args.push(TypeTree(vec![tt]));
    let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
    args.push(TypeTree(vec![i64_tt]));
    if !da.is_empty() {
        let activity = match da[i] {
            DiffActivity::DualOnly
            | DiffActivity::Dual
            | DiffActivity::DuplicatedOnly
            | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
            DiffActivity::Const => DiffActivity::Const,
            _ => panic!("unexpected activity for ptr/ref"),
        };
        new_activities.push(activity);
        new_positions.push(i + 1);
    }
    trace!("ABI MATCHING!");
    continue;
}
mansplaining(lol):
We are handling three different pointer types: raw pointers (*const T / *mut T), references (&T / &mut T), and Box<T>.
Why All Three Share the Same Slice Logic
Think of it like this:
What These Types Really Are:
- &[f32] = “pointer to some f32s + length”
- *mut [f32] = “pointer to some f32s + length”
- Box<[f32]> = “pointer to some f32s + length”
At the memory level, they’re all the same thing!
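Since the three pointer kinds behave identically here, the two matches in the Todo snippet could collapse into one helper that returns the slice element type directly. A sketch over a hypothetical mini type model: Ty below is not rustc's TyKind, and BoxOf stands in for the Adt-with-is_box() case.

```rust
// Hypothetical mini type model standing in for rustc's `ty::TyKind`,
// with only the variants needed to recognize pointers to slices.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    F32,
    I32,
    Slice(Box<Ty>),  // [T]
    Ref(Box<Ty>),    // &T / &mut T
    RawPtr(Box<Ty>), // *const T / *mut T
    BoxOf(Box<Ty>),  // Box<T>
}

/// One helper replacing both matches: if `ty` is any pointer-like type whose
/// pointee is a slice, return the slice's element type; otherwise `None`.
fn slice_elem_of_pointer(ty: &Ty) -> Option<&Ty> {
    let pointee = match ty {
        Ty::Ref(inner) | Ty::RawPtr(inner) | Ty::BoxOf(inner) => inner.as_ref(),
        _ => return None,
    };
    match pointee {
        Ty::Slice(elem) => Some(elem.as_ref()),
        _ => None,
    }
}
```

Returning `Option` folds the `points_to_slice` check and the element extraction into one step, which also removes the `unwrap()` and `unreachable!()` calls.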
What Happens at the LLVM Level:
// In Rust you write:
fn process_slice(data: &[f32]) { ... }
fn process_ptr(data: *mut [f32]) { ... }
fn process_box(data: Box<[f32]>) { ... }
; But LLVM sees all of them as:
define void @process_slice(float* %data_ptr, i64 %data_len)
define void @process_ptr(float* %data_ptr, i64 %data_len)
define void @process_box(float* %data_ptr, i64 %data_len)
; IDENTICAL PARAMETER LISTS!
The ABI (Application Binary Interface) Reality:
All slice types get “lowered” to the same thing:
- A pointer to the data (float*)
- The length as a separate integer (i64)
Why TypeTree Cares:
Enzyme (the autodiff backend) needs to know:
- “Parameter 1: pointer to array of floats”
- “Parameter 2: integer length (not differentiable)”
It doesn’t matter if the original Rust code used &[f32], *mut [f32], or Box<[f32]> - they all become the same memory layout.
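The Todo snippet's handling of such an argument can be sketched standalone: one slice argument yields two TypeTrees (a pointer with the element tree as its child, then an integer length), and the extra slot gets FakeActivitySize when the argument is Dual/Duplicated, or Const when it is Const. This is our reading of the snippet's new_activities/new_positions logic; Kind, Type, TypeTree and Activity are reduced stand-ins (Activity copies only the DiffActivity variants the snippet mentions), and lower_slice_arg is a hypothetical name.

```rust
// Simplified stand-ins for the rustc types; not the real definitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    Float,
    Integer,
    Pointer,
}

#[derive(Debug, Clone, PartialEq)]
struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

#[derive(Debug, Clone, PartialEq)]
struct TypeTree(Vec<Type>);

// Only the DiffActivity variants that appear in the Todo snippet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Activity {
    Const,
    Duplicated,
    DuplicatedOnly,
    Dual,
    DualOnly,
    FakeActivitySize,
}

/// Hypothetical helper: trees for the (ptr, len) pair a slice argument lowers
/// to, plus the activity for the extra length slot, mirroring the snippet.
fn lower_slice_arg(elem: TypeTree, activity: Activity) -> ([TypeTree; 2], Activity) {
    let ptr = Type { offset: -1, size: 8, kind: Kind::Pointer, child: elem };
    let len = Type { offset: -1, size: 8, kind: Kind::Integer, child: TypeTree(vec![]) };
    let len_activity = match activity {
        Activity::Dual | Activity::DualOnly | Activity::Duplicated | Activity::DuplicatedOnly => {
            Activity::FakeActivitySize
        }
        Activity::Const => Activity::Const,
        other => panic!("unexpected activity for ptr/ref: {other:?}"),
    };
    ([TypeTree(vec![ptr]), TypeTree(vec![len])], len_activity)
}
```

The length slot never receives a real derivative: FakeActivitySize only keeps the activity list aligned with the two lowered parameters.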
These become (ptr, len) in LLVM:
- &[T] → (T*, usize)
- &mut [T] → (T*, usize)
- *const [T] → (T*, usize)
- *mut [T] → (T*, usize)
- Box<[T]> → (T*, usize)
- &str → (u8*, usize)
- &mut str → (u8*, usize)
- *const str → (u8*, usize)
- *mut str → (u8*, usize)
- Box<str> → (u8*, usize)
These have a different ABI:
- Vec<T> → (T*, usize, usize) (ptr, len, capacity)
- String → (u8*, usize, usize) (ptr, len, capacity)