TypeTrees for Autodiff

What are TypeTrees?

TypeTrees are memory layout descriptors for the Enzyme autodiff backend. They tell Enzyme exactly how types are laid out in memory so that it can compute derivatives efficiently.

Structure

TypeTree(Vec<Type>)
 
Type {
    offset: isize,  // byte offset (-1 = everywhere)
    size: usize,    // size in bytes
    kind: Kind,     // Float, Integer, Pointer, etc.
    child: TypeTree // nested structure
}
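The structure above is written as shorthand. A minimal compilable sketch of it could look like the following; these are simplified stand-ins for illustration, not the definitions used in rustc, and the `depth` method is a hypothetical helper added only to show how the nesting works.

```rust
// Simplified stand-ins for the types described in these notes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Kind {
    Float,
    Integer,
    Pointer,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Type {
    pub offset: isize,   // byte offset (-1 = everywhere)
    pub size: usize,     // size in bytes
    pub kind: Kind,      // what kind of data lives here
    pub child: TypeTree, // nested structure (empty for leaves)
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
    /// An empty tree, used as the child of leaf entries.
    pub fn new() -> Self {
        TypeTree(Vec::new())
    }

    /// Hypothetical helper: nesting depth (0 for an empty tree).
    pub fn depth(&self) -> usize {
        self.0.iter().map(|t| 1 + t.child.depth()).max().unwrap_or(0)
    }
}
```

The recursion between `Type` and `TypeTree` is fine for the compiler because the `Vec` provides the indirection.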

Example: fn compute(x: &f32, data: &[f32]) -> f32

Input 0: x: &f32

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: 0, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])

Input 1: data: &[f32]

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])

Output: f32

TypeTree(vec![Type {
    offset: 0, size: 4, kind: Float,
    child: TypeTree::new()
}])
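The three trees for `compute` can be built with two small helpers. This is only a sketch: `leaf`, `pointer_to` and `compute_trees` are hypothetical names, the types are simplified stand-ins, and the pointer size of 8 assumes a 64-bit target as in the example above.

```rust
// Simplified stand-ins for the types in these notes, not the rustc definitions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Kind {
    Float,
    Integer,
    Pointer,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Type {
    pub offset: isize,
    pub size: usize,
    pub kind: Kind,
    pub child: TypeTree,
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TypeTree(pub Vec<Type>);

/// Hypothetical helper: a single entry with no nested structure.
fn leaf(offset: isize, size: usize, kind: Kind) -> TypeTree {
    TypeTree(vec![Type { offset, size, kind, child: TypeTree::default() }])
}

/// Hypothetical helper: an 8-byte pointer (64-bit target assumed) to `pointee`.
fn pointer_to(pointee: TypeTree) -> TypeTree {
    TypeTree(vec![Type { offset: 0, size: 8, kind: Kind::Pointer, child: pointee }])
}

/// Trees for `fn compute(x: &f32, data: &[f32]) -> f32`, as (inputs, output).
fn compute_trees() -> (Vec<TypeTree>, TypeTree) {
    let x = pointer_to(leaf(0, 4, Kind::Float)); // exactly one f32 at offset 0
    let data = pointer_to(leaf(-1, 4, Kind::Float)); // f32 at every element position
    let ret = leaf(0, 4, Kind::Float);
    (vec![x, data], ret)
}
```

The only difference between the `x` and `data` trees is the offset of the inner float: 0 for a single value, -1 for a repeated element.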

Why Needed?

  • Enzyme can’t deduce complex type layouts from LLVM IR
  • Prevents slow memory pattern analysis
  • Enables correct derivative computation for nested structures
  • Tells Enzyme which bytes are differentiable vs metadata

What Enzyme Does With This Information:

Without TypeTrees (current state):

; Enzyme sees generic LLVM IR:
define float @distance(i8* %p1, i8* %p2) {
  ; Has to guess what these pointers point to
  ; Slow analysis of all memory operations
  ; May miss optimization opportunities
}

With TypeTrees (our goal):

// Enzyme knows:
// - %p1 points to struct with f32 at +0, f32 at +4, i32 at +8
// - Only the f32 fields need derivatives
// - Can generate efficient derivative code directly

TypeTrees - Offset and -1 Explained

Type Structure

 
Type {
    offset: isize,  // WHERE this type starts
    size: usize,    // HOW BIG this type is
    kind: Kind,     // WHAT KIND of data (Float, Int, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
 

Offset Values

Regular Offset (0, 4, 8, etc.)

Specific byte position within a structure

 
struct Point {
    x: f32,  // offset 0, size 4
    y: f32,  // offset 4, size 4
    id: i32, // offset 8, size 4
}
 

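TypeTree for Point (the pointee of &Point; the reference itself adds an outer Pointer entry, as in the &f32 example):

TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float, child: TypeTree::new() },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float, child: TypeTree::new() },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer, child: TypeTree::new() }, // id at byte 8
])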
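The byte offsets for Point can be checked with `std::mem::offset_of!` (stable since Rust 1.77). This sketch assumes `#[repr(C)]`, because the default Rust representation does not guarantee field order; `point_field_layout` is a hypothetical helper name.

```rust
use std::mem::{offset_of, size_of};

// `repr(C)` fixes the field order so that the offsets are predictable.
#[repr(C)]
struct Point {
    x: f32,
    y: f32,
    id: i32,
}

/// Returns (field name, byte offset, size in bytes) for each field of `Point`.
fn point_field_layout() -> [(&'static str, usize, usize); 3] {
    [
        ("x", offset_of!(Point, x), size_of::<f32>()),
        ("y", offset_of!(Point, y), size_of::<f32>()),
        ("id", offset_of!(Point, id), size_of::<i32>()),
    ]
}
```

Each tuple corresponds to one `Type` entry in the tree: the offset and size come from the layout, and the kind comes from the field type.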
 

Offset -1 (Special: “Everywhere”)

Means “this pattern repeats for ALL elements”

Example 1: Array [f32; 100]

 
TypeTree(vec![Type {
    offset: -1,  // ALL positions
    size: 4,     // each f32 is 4 bytes
    kind: Float, // every element is float
}])
 

This single entry replaces 100 separate Types with offsets 0, 4, 8, 12, …, 396.
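To make the saving concrete, here is a small sketch of what the explicit form would look like if the -1 entry were expanded by hand. `expand_everywhere` is a hypothetical helper for illustration only; nothing in these notes says Enzyme or rustc performs such an expansion.

```rust
/// Expand one "everywhere" (-1) entry into explicit (offset, size) pairs
/// for an array of `count` elements of `elem_size` bytes each.
fn expand_everywhere(elem_size: usize, count: usize) -> Vec<(isize, usize)> {
    (0..count)
        .map(|i| ((i * elem_size) as isize, elem_size))
        .collect()
}
```

For `[f32; 100]` this gives 100 pairs, the last one at offset 396, all carrying the same information as the single -1 entry.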

Example 2: Slice &[i32]

 
// Pointer to slice data
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer
    }])
}])
 

Example 3: Mixed Structure

 
struct Container {
    header: i64,       // offset 0
    data: [f32; 1000], // offset 8, but elements use -1
}

TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer }, // header
    // Note: `data` is stored inline in the struct, not behind a pointer,
    // so `Pointer` here is at best a stand-in for a nested container entry.
    Type {
        offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float // ALL array elements
        }])
    }
])
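The layout numbers in this example (offset 8, size 4000) can be checked the same way as for a plain struct. The sketch below assumes `#[repr(C)]` so that the field order is fixed; `container_layout` is a hypothetical helper name. It also shows that the array occupies 4000 bytes inside the struct itself, whereas a reference to the same array would only take one pointer-sized word.

```rust
use std::mem::{offset_of, size_of};

#[repr(C)]
struct Container {
    header: i64,
    data: [f32; 1000],
}

/// Returns (offset of header, offset of data, size of data, size of Container).
fn container_layout() -> (usize, usize, usize, usize) {
    (
        offset_of!(Container, header),
        offset_of!(Container, data),
        size_of::<[f32; 1000]>(),
        size_of::<Container>(),
    )
}

/// Size of a reference to the array: a single thin pointer.
fn array_ref_size() -> usize {
    size_of::<&[f32; 1000]>()
}
```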

Core TypeTree Generation Functions:

// compiler/rustc_middle/src/ty/mod.rs

pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
    // Creates TypeTree from any Rust type
}

pub fn fnc_typetrees<'tcx>(
    tcx: TyCtxt<'tcx>,
    fn_ty: Ty<'tcx>,
    da: &mut Vec<DiffActivity>,
    span: Option<Span>,
) -> FncTree {
    // Handles function signatures, slice ABI adjustments, activity modifications
}

fn typetree_from_ty<'a>(...) -> TypeTree {
    // Handles all Rust types: primitives, ADTs, arrays, slices, pointers
}

Enums and unions are not handled

base.rs is where all autodiff functions get collected across the entire crate:

  • Functions marked with #[autodiff]

  • Their generated TypeTrees

  • Activity information

  • Target function names

What is FieldsShape?

FieldsShape is an enum from rustc_abi that describes how the fields of a type are laid out in memory. It matters for TypeTree generation because it gives the exact position of each field in memory.

// Simplified; the exact field types in rustc_abi vary between compiler versions.
pub enum FieldsShape {
    /// Scalar primitives and `!`, which never have fields.
    Primitive,

    /// Array/vector layout: all elements have the same layout,
    /// repeated `count` times with a `stride` between them.
    Array { stride: Size, count: u64 },

    /// Struct/tuple layout: arbitrary positioning of differently-sized fields.
    /// The `offsets` vector tells you where each field starts.
    Arbitrary {
        offsets: &'tcx [Size],
        memory_index: &'tcx [u32]
    },

    /// Union layout: all fields start at offset 0.
    /// The `usize` is the field count.
    Union(NonZero<usize>),
}
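How each variant turns into field offsets can be sketched with a toy version of the enum. `SimpleFieldsShape` and `field_offsets` are hypothetical names; the toy uses plain integers in place of rustc's `Size` and leaves out `memory_index`, so it shows the idea only and is not the compiler's API.

```rust
// Hypothetical, simplified mirror of FieldsShape using plain integers.
enum SimpleFieldsShape {
    Primitive,
    Array { stride: u64, count: u64 },
    Arbitrary { offsets: Vec<u64> },
    Union(usize),
}

/// Byte offset of every field described by `shape`.
fn field_offsets(shape: &SimpleFieldsShape) -> Vec<u64> {
    match shape {
        // Primitives have no fields at all.
        SimpleFieldsShape::Primitive => Vec::new(),
        // Element i of an array starts at i * stride.
        SimpleFieldsShape::Array { stride, count } => {
            (0..*count).map(|i| i * *stride).collect()
        }
        // Structs and tuples carry their offsets explicitly.
        SimpleFieldsShape::Arbitrary { offsets } => offsets.clone(),
        // Every union field starts at offset 0.
        SimpleFieldsShape::Union(count) => vec![0; *count],
    }
}
```

The `Array` case is the one that the -1 offset compresses: every element has the same layout, so a single entry can stand for all of them.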

Todo:

The ABI matching logic for slices in the TypeTree code can be abstracted out, like:

// Check if this type points to a slice
let points_to_slice = match ty.kind() {
    ty::RawPtr(inner_ty, _) | ty::Ref(_, inner_ty, _) => inner_ty.is_slice(),
    ty::Adt(adt_def, args_inner) if adt_def.is_box() => {
        args_inner.type_at(0).is_slice()
    }
    _ => false,
};
 
if points_to_slice {
    // Extract the slice element type
    let slice_elem_ty = match ty.kind() {
        ty::RawPtr(inner_ty, _) | ty::Ref(_, inner_ty, _) => 
            inner_ty.builtin_index().unwrap(),
        ty::Adt(_, args_inner) => 
            args_inner.type_at(0).builtin_index().unwrap(),
        _ => unreachable!(),
    };
    
    // Single slice handling logic
    let child = typetree_from_ty(slice_elem_ty, tcx, 1, safety, &mut visited, span);
    let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
    args.push(TypeTree(vec![tt]));
    let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
    args.push(TypeTree(vec![i64_tt]));
    
    if !da.is_empty() {
        let activity = match da[i] {
            DiffActivity::DualOnly | DiffActivity::Dual | 
            DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => 
                DiffActivity::FakeActivitySize,
            DiffActivity::Const => DiffActivity::Const,
            _ => panic!("unexpected activity for ptr/ref"),
        };
        new_activities.push(activity);
        new_positions.push(i + 1);
    }
    trace!("ABI MATCHING!");
    continue;
}
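The snippet above matches on the same pointer kinds twice: once to detect a slice and once to extract the element type. One possible way to fold this into a single helper that returns an `Option` is sketched below, on a toy type enum. `ToyTy` and `pointee_slice_elem` are hypothetical stand-ins for illustration; the real code would work on rustc's `Ty` instead.

```rust
// Toy stand-in for rustc's `Ty`; all names here are hypothetical.
#[derive(Debug, PartialEq)]
enum ToyTy {
    F32,
    Slice(Box<ToyTy>),
    Ref(Box<ToyTy>),
    RawPtr(Box<ToyTy>),
    BoxOf(Box<ToyTy>),
}

/// Returns the element type if `ty` is a reference, raw pointer or box
/// whose pointee is a slice, and `None` otherwise.
fn pointee_slice_elem(ty: &ToyTy) -> Option<&ToyTy> {
    // Step 1: peel off the pointer-like wrapper, whichever one it is.
    let pointee: &ToyTy = match ty {
        ToyTy::Ref(inner) | ToyTy::RawPtr(inner) | ToyTy::BoxOf(inner) => &**inner,
        _ => return None,
    };
    // Step 2: only slices qualify for the (ptr, len) ABI handling.
    match pointee {
        ToyTy::Slice(elem) => Some(&**elem),
        _ => None,
    }
}
```

With a helper of this shape, the caller becomes a single `if let Some(elem) = ...` and the `unreachable!()` arm is no longer needed.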

Explanation:

We handle three different pointer types (*T, &T, Box<T>), but they all need the same slice ABI logic.

Why All Three Share the Same Slice Logic

Think of it like this:

What These Types Really Are:

  • &[f32] = “pointer to some f32s + length”

  • *mut [f32] = “pointer to some f32s + length”

  • Box<[f32]> = “pointer to some f32s + length”

At the memory level, they’re all the same thing!

What Happens at LLVM Level:

// In Rust you write:
fn process_slice(data: &[f32]) { ... }
fn process_ptr(data: *mut [f32]) { ... }
fn process_box(data: Box<[f32]>) { ... }

// But LLVM sees all of them as:
define void @process_slice(float* %data_ptr, i64 %data_len)
define void @process_ptr(float* %data_ptr, i64 %data_len)
define void @process_box(float* %data_ptr, i64 %data_len)
//                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//                      IDENTICAL SIGNATURES!

The ABI (Application Binary Interface) Reality:

All slice types get “lowered” to the same thing:

  1. Pointer to the data (float*)

  2. Length as separate integer (i64)
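The "pointer plus length" shape is visible from plain Rust as well. The sketch below (with hypothetical helper names) checks that all three slice pointer types are two machine words wide and splits a slice into its two parts.

```rust
use std::mem::size_of;

/// Sizes in bytes of the three slice pointer types discussed above.
fn fat_pointer_sizes() -> [usize; 3] {
    [
        size_of::<&[f32]>(),
        size_of::<*mut [f32]>(),
        size_of::<Box<[f32]>>(),
    ]
}

/// Split a slice into the (data pointer, length) pair it is made of.
fn split_slice(data: &[f32]) -> (*const f32, usize) {
    (data.as_ptr(), data.len())
}
```

On a 64-bit target each size is 16 bytes: 8 for the data pointer and 8 for the length.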

Why TypeTree Cares:

Enzyme (the autodiff backend) needs to know:

  • “Parameter 1: pointer to array of floats”

  • “Parameter 2: integer length (not differentiable)”

It doesn’t matter if the original Rust code used &[f32], *mut [f32], or Box<[f32]> - they all become the same memory layout.
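In TypeTree terms, a slice-like parameter therefore produces two argument trees, one for the data pointer and one for the length. The sketch below follows the offsets and sizes used in the ABI matching snippet earlier in these notes (offset -1, size 8 for both entries) and assumes a leaf element type; the types are simplified stand-ins and `slice_arg_trees` is a hypothetical helper name.

```rust
// Simplified stand-ins for the types in these notes, not the rustc definitions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kind {
    Float,
    Integer,
    Pointer,
}

#[derive(Clone, Debug, PartialEq, Eq)]
struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct TypeTree(Vec<Type>);

/// Hypothetical helper: the two argument trees for one slice-like parameter
/// whose elements are leaves of the given kind and size.
fn slice_arg_trees(elem_kind: Kind, elem_size: usize) -> [TypeTree; 2] {
    // Every element of the slice has the same layout, hence offset -1.
    let elems = TypeTree(vec![Type {
        offset: -1,
        size: elem_size,
        kind: elem_kind,
        child: TypeTree::default(),
    }]);
    // Argument 1: pointer to the elements.
    let ptr = TypeTree(vec![Type { offset: -1, size: 8, kind: Kind::Pointer, child: elems }]);
    // Argument 2: the length, a plain integer with nothing nested inside.
    let len = TypeTree(vec![Type {
        offset: -1,
        size: 8,
        kind: Kind::Integer,
        child: TypeTree::default(),
    }]);
    [ptr, len]
}
```

The same function works for `&[f32]`, `*mut [f32]` and `Box<[f32]>`, since only the element type differs between slice parameters.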

These become (ptr, len) in LLVM:

  • &[T] → (T*, usize)

  • &mut [T] → (T*, usize)

  • *const [T] → (T*, usize)

  • *mut [T] → (T*, usize)

  • Box<[T]> → (T*, usize)

  • &str → (u8*, usize)

  • &mut str → (u8*, usize)

  • *const str → (u8*, usize)

  • *mut str → (u8*, usize)

  • Box<str> → (u8*, usize)

These have a different ABI:

  • Vec<T> → (T*, usize, usize) (ptr, len, capacity, in unspecified field order)

  • String → (u8*, usize, usize) (ptr, len, capacity, in unspecified field order)
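The size difference can be observed directly. This sketch (hypothetical helper names) measures the types in machine words; the three-word size of `Vec` and `String` reflects the current standard library implementation rather than a documented guarantee.

```rust
use std::mem::size_of;

/// Sizes of Vec<f32>, String and &[f32], measured in machine words.
fn owned_vs_borrowed_words() -> (usize, usize, usize) {
    let word = size_of::<usize>();
    (
        size_of::<Vec<f32>>() / word,
        size_of::<String>() / word,
        size_of::<&[f32]>() / word,
    )
}

/// Borrowing a Vec as a slice gives back the two-part (ptr, len) form.
fn vec_as_pair(v: &Vec<f32>) -> (*const f32, usize) {
    let s: &[f32] = v.as_slice();
    (s.as_ptr(), s.len())
}
```

So a function that takes `&[f32]` can still be fed from a `Vec<f32>` through `as_slice`, and only then does the slice ABI handling above apply.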