Surgeon
A surgeon that transplants weights from one model to another according to a SurgeryPlan.
The SurgeryPlan only specifies where the shrinkage happens. When transplant is called, the surgeon will:
- use lesion_transplant to copy over the weights that require shrinking.
- use copy_transplant on the remaining weights.
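The two-step dispatch above can be sketched in plain NumPy. This is a minimal, hypothetical illustration, not the library's code: it assumes weights are available as a dict of arrays keyed by name, and `shrink_axes` (a made-up mapping from weight name to the axes that need slicing) stands in for whatever the SurgeryPlan actually records. The weight names `w_xh`, `w_hh` and `w_pp` are also invented for the example.

```python
import numpy as np


def sketch_transplant(donor_weights, shrink_axes, keep_idx):
    """Hypothetical sketch of the lesion-vs-copy dispatch in Surgeon.transplant.

    donor_weights: dict mapping weight name -> np.ndarray (assumed layout).
    shrink_axes:   dict mapping weight name -> list of axes to shrink
                   (hypothetical stand-in for the SurgeryPlan).
    keep_idx:      indices of the units that survive the shrinkage.
    """
    recipient_weights = {}
    for name, w in donor_weights.items():
        if name in shrink_axes:
            # "lesion" path: keep only the surviving units along each listed axis
            sliced = w
            for axis in shrink_axes[name]:
                sliced = np.take(sliced, keep_idx, axis=axis)
            recipient_weights[name] = sliced
        else:
            # "copy" path: weights untouched by the shrinkage are copied as-is
            recipient_weights[name] = w.copy()
    return recipient_weights


# Toy example: 10 hidden units shrunk to 7 (as with shrink_rate=0.3 in the Example)
rng = np.random.default_rng(0)
donor = {
    "w_xh": rng.normal(size=(4, 10)),   # input -> hidden: hidden units on axis 1
    "w_hh": rng.normal(size=(10, 10)),  # hidden -> hidden: both axes are hidden units
    "w_pp": rng.normal(size=(9, 9)),    # weight unrelated to the hidden layer
}
plan_axes = {"w_xh": [1], "w_hh": [0, 1]}
keep = [0, 1, 2, 4, 5, 7, 9]
recipient = sketch_transplant(donor, plan_axes, keep)
print({k: v.shape for k, v in recipient.items()})
# {'w_xh': (4, 7), 'w_hh': (7, 7), 'w_pp': (9, 9)}
```

Only the weights listed in the plan change shape; everything else arrives in the recipient unchanged.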
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| surgery_plan | SurgeryPlan | A surgery plan for the transplant. | required |
Example
```python
import tensorflow as tf
from connectionist.data import ToyOP
from connectionist.models import PMSP
from connectionist.surgery import SurgeryPlan, Surgeon, make_recipient

# Create model and train
data = ToyOP()
model = PMSP(tau=0.2, h_units=10, p_units=9, c_units=5)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
)
model.fit(data.x_train, data.y_train, epochs=10, batch_size=20)

# Create surgery plan and surgeon
plan = SurgeryPlan(layer='hidden', original_units=10, shrink_rate=0.3, make_model_fn=PMSP)
surgeon = Surgeon(surgery_plan=plan)

# Create and build the recipient model
new_model = make_recipient(model=model, layer=plan.layer, keep_n=plan.keep_n, make_model_fn=plan.make_model_fn)
new_model.build(input_shape=model.pmsp._build_input_shape)

# Transplant weights and biases
surgeon.transplant(donor=model, recipient=new_model)
```
See also connectionist.models.PMSP.shrink_layer for a high-level API.
lesion_transplant(donor, recipient, weight_name, keep_idx, axes)
Transplant the weights that require shrinking from donor to recipient, keeping only the entries at keep_idx along each axis in axes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| donor | tf.keras.Model | The donor model. | required |
| recipient | tf.keras.Model | The recipient model. | required |
| weight_name | str | The name of the weights to transplant. | required |
| keep_idx | List[int] | The indices to keep. | required |
| axes | List[int] | The axes along which to slice the weights. Usually a single axis, but two for a self-connecting weight. | required |
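Assuming the weights are plain arrays whose unit dimensions are indexed by keep_idx, the effect of the axes argument can be illustrated with NumPy's `np.take`. `slice_weight` is a hypothetical helper, not part of `connectionist`; the real lesion_transplant also reads from and writes to Keras models, which this sketch leaves out.

```python
import numpy as np


def slice_weight(weight, keep_idx, axes):
    """Hypothetical helper: keep only `keep_idx` along every axis in `axes`."""
    for axis in axes:
        # np.take with a list of indices keeps the number of dimensions,
        # so the order in which the axes are processed does not matter
        weight = np.take(weight, keep_idx, axis=axis)
    return weight


keep_idx = [0, 2]  # keep units 0 and 2 of a 3-unit layer

# Self-connecting weight (layer -> same layer): both axes index the shrunk layer
w_self = np.arange(9).reshape(3, 3)
print(slice_weight(w_self, keep_idx, axes=[0, 1]))
# [[0 2]
#  [6 8]]

# Weight feeding into the shrunk layer (2 inputs -> 3 units): only axis 1 shrinks
w_in = np.arange(6).reshape(2, 3)
print(slice_weight(w_in, keep_idx, axes=[1]))
# [[0 2]
#  [3 5]]

# Bias of the shrunk layer: a single axis
bias = np.array([10, 20, 30])
print(slice_weight(bias, keep_idx, axes=[0]))
# [10 30]
```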
transplant(donor, recipient)
Transplant all weights from donor to recipient model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| donor | tf.keras.Model | The donor model. | required |
| recipient | tf.keras.Model | The recipient model. | required |