Surgeon

A surgeon transplanting weights from one model to another according to a SurgeryPlan.

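The SurgeryPlan only specifies where the shrinkage happens. Calling transplant then copies the weights from the donor to the recipient, slicing the weights connected to the shrunk layer so that only the kept units remain.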

Parameters:

Name Type Description Default
surgery_plan SurgeryPlan

A surgery plan for the transplant.

required

Example

import tensorflow as tf
from connectionist.data import ToyOP
from connectionist.models import PMSP
from connectionist.surgery import SurgeryPlan, Surgeon, make_recipient

# Create model and train
data = ToyOP()
model = PMSP(tau=0.2, h_units=10, p_units=9, c_units=5)

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
)
model.fit(data.x_train, data.y_train, epochs=10, batch_size=20)

# Create surgery plan and surgeon
plan = SurgeryPlan(layer='hidden', original_units=10, shrink_rate=0.3, make_model_fn=PMSP)
surgeon = Surgeon(surgery_plan=plan)

# Create and build the recipient model
new_model = make_recipient(model=model, layer=plan.layer, keep_n=plan.keep_n, make_model_fn=plan.make_model_fn)
new_model.build(input_shape=model.pmsp._build_input_shape)

# Transplant weights and biases
surgeon.transplant(donor=model, recipient=new_model)

See also connectionist.models.PMSP.shrink_layer for the high-level API.
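To make the role of the plan more concrete, here is a minimal sketch of how a shrink rate could be turned into a set of surviving unit indices. ToyPlan is a hypothetical stand-in, not the library's SurgeryPlan. The rounding rule for keep_n, the random choice of units, and the seed field are all assumptions made for illustration; only the names layer, original_units, shrink_rate and keep_n come from the example above.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ToyPlan:
    """Hypothetical stand-in for a surgery plan (not the library class)."""

    layer: str
    original_units: int
    shrink_rate: float
    keep_n: int = field(init=False)
    keep_idx: list = field(init=False)
    seed: int = 0  # assumption: seeded so the sketch is reproducible

    def __post_init__(self):
        # Assumption: the number of surviving units is the rounded remainder.
        self.keep_n = int(round(self.original_units * (1 - self.shrink_rate)))
        # Assumption: surviving units are picked at random without replacement.
        rng = np.random.default_rng(self.seed)
        chosen = rng.choice(self.original_units, size=self.keep_n, replace=False)
        self.keep_idx = sorted(int(i) for i in chosen)


toy_plan = ToyPlan(layer="hidden", original_units=10, shrink_rate=0.3)
print(toy_plan.keep_n)         # 7
print(len(toy_plan.keep_idx))  # 7
```

Under these assumptions, a layer of 10 units with a shrink rate of 0.3 keeps 7 units, and keep_idx names which 7 of the original units survive.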

lesion_transplant(donor, recipient, weight_name, keep_idx, axes)

Transplant a weight that requires shrinking from donor to recipient.

Parameters:

Name Type Description Default
donor tf.keras.Model

The donor model.

required
recipient tf.keras.Model

The recipient model.

required
weight_name str

The name of the weights to transplant.

required
keep_idx List[int]

The indices to keep.

required
axes List[int]

The axes along which to slice the weights. Usually a single axis, but two when the weight is self-connecting.

required
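The interplay of keep_idx and axes may be easier to see on plain arrays. The sketch below uses NumPy rather than the library itself, and slice_weights is a hypothetical helper, not part of connectionist. It shows why a weight leaving the shrunk layer needs one axis sliced, while a self-connecting weight needs both.

```python
import numpy as np


def slice_weights(weights, keep_idx, axes):
    """Keep only the entries at keep_idx along each axis listed in axes."""
    sliced = weights
    for axis in axes:
        sliced = np.take(sliced, keep_idx, axis=axis)
    return sliced


keep_idx = [0, 2, 3]  # units of a 4-unit layer that survive

# Weight leaving the shrunk layer, shape (hidden=4, output=2): slice axis 0 only.
w_out = np.arange(8).reshape(4, 2)
w_out_small = slice_weights(w_out, keep_idx, axes=[0])

# Self-connecting weight, shape (hidden=4, hidden=4): slice both axes.
w_self = np.arange(16).reshape(4, 4)
w_self_small = slice_weights(w_self, keep_idx, axes=[0, 1])

print(w_out_small.shape)   # (3, 2)
print(w_self_small.shape)  # (3, 3)
```

Slicing a self-connecting weight on only one axis would leave a (3, 4) matrix, which no longer matches a 3-unit layer on both sides.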

transplant(donor, recipient)

Transplant all weights from donor to recipient model.

Parameters:

Name Type Description Default
donor tf.keras.Model

The donor model.

required
recipient tf.keras.Model

The recipient model.

required
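As a rough illustration of a whole-model transplant, the sketch below treats each model as a dictionary of NumPy arrays. This is an assumption for the sake of a self-contained example, not how the library stores weights. The weight names, the shrink_axes mapping, and transplant_all are all hypothetical. Weights that do not touch the shrunk layer are copied unchanged; the others are sliced first.

```python
import numpy as np


def transplant_all(donor, recipient, shrink_axes, keep_idx):
    """Copy every donor weight into recipient, slicing where the layer shrank.

    donor, recipient: dict of weight name -> np.ndarray (stand-ins for models).
    shrink_axes: dict of weight name -> axes to slice (hypothetical mapping).
    """
    for name, donor_w in donor.items():
        new_w = donor_w
        for axis in shrink_axes.get(name, []):
            new_w = np.take(new_w, keep_idx, axis=axis)
        if new_w.shape != recipient[name].shape:
            raise ValueError(f"Shape mismatch for {name}")
        recipient[name] = new_w.copy()
    return recipient


rng = np.random.default_rng(0)
donor = {
    "w_in": rng.normal(size=(3, 4)),    # input -> hidden
    "w_self": rng.normal(size=(4, 4)),  # hidden -> hidden
    "b_hidden": rng.normal(size=(4,)),
    "w_out": rng.normal(size=(4, 2)),   # hidden -> output
    "b_out": rng.normal(size=(2,)),     # untouched by the shrinkage
}
keep_idx = [0, 1, 3]
shrink_axes = {"w_in": [1], "w_self": [0, 1], "b_hidden": [0], "w_out": [0]}
recipient = {
    "w_in": np.zeros((3, 3)),
    "w_self": np.zeros((3, 3)),
    "b_hidden": np.zeros((3,)),
    "w_out": np.zeros((3, 2)),
    "b_out": np.zeros((2,)),
}

result = transplant_all(donor, recipient, shrink_axes, keep_idx)
print({name: w.shape for name, w in result.items()})
```

The shape check is a deliberate design choice in this sketch: it fails loudly if the recipient was built with a different number of units than keep_idx implies.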