Open
Description
Hi author,
I have downloaded the pretrained weights from Multiview Contrast. I want to load these weights and then pass a PDB file to the model, so that I can get the latent representation of the protein from the model.
I think I can now load the released pretrained weights successfully; however, I am stuck at passing the PDB file to the model. Here is how I pass it:
import torch
from torchdrug.data import Protein

def extract_representation(model, protein_structure_path):
    protein = Protein.from_pdb(protein_structure_path)
    _protein = Protein.pack([protein])
    input_feature = protein.atom2graph
    with torch.no_grad():
        representation = model(_protein, input_feature)
    # Return the output, which is the protein's representation
    return representation

I think the GearNet model takes two parameters: 1. a Graph object, and 2. input (I'm not sure what this is). However, it seems that either _protein or input_feature is wrong. With this code, I get the following error:
Traceback (most recent call last):
  File "/content/GearNet-main/script/test.py", line 62, in <module>
    main()
  File "/content/GearNet-main/script/test.py", line 58, in main
    representation = extract_representation(model, args.pdb)
  File "/content/GearNet-main/script/test.py", line 43, in extract_representation
    representation = model(_protein, input_feature)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/GearNet-main/gearnet/model.py", line 115, in forward
    edge_hidden = self.edge_layers[i](line_graph, edge_hidden)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/GearNet-main/gearnet/layer.py", line 132, in forward
    update = self.aggregate(graph, message)
  File "/content/GearNet-main/gearnet/layer.py", line 118, in aggregate
    update = update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[4944, 472]' is invalid for input of size 711936

Can you please help me with how to pass the PDB file to the model?
Thank you
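For anyone debugging the same traceback: the numbers in the RuntimeError can be checked with a few lines of arithmetic. The sketch below assumes the Multiview Contrast checkpoint is the GearNet-Edge variant with num_angle_bin=8 and edge_input_dim=59; these values are not stated in this issue, although 8 × 59 = 472 matches the error. Under that assumption, update holds 18 features per edge and angle bin instead of 59, so the edge features of the graph passed to the model do not have the width the pretrained edge layers expect. A plausible (unconfirmed) cause is that Protein.from_pdb returns an atom-level graph with bond features, whereas GearNet-Edge is normally fed a residue-level graph whose edges and edge features come from a separate graph-construction step.

```python
# Numbers copied from the RuntimeError raised at gearnet/layer.py, line 118:
#   update.view(graph.num_node, self.num_relation * self.input_dim)
num_node = 4944       # graph.num_node of the line graph built inside the model
expected_width = 472  # self.num_relation * self.input_dim
actual_size = 711936  # total number of elements in `update`

# Assumed GearNet-Edge hyperparameters (an assumption, not taken from this issue).
num_angle_bin = 8     # number of relations used by the edge (line-graph) layers
edge_input_dim = 59   # edge feature width the pretrained edge layers expect
assert num_angle_bin * edge_input_dim == expected_width

# Per-node width that `update` actually has (must divide evenly).
actual_width, rem = divmod(actual_size, num_node)
assert rem == 0

# Edge feature width that actually reached the first edge layer.
actual_edge_dim, rem = divmod(actual_width, num_angle_bin)
assert rem == 0

print(f"expected {num_angle_bin} x {edge_input_dim} = {expected_width} per node")
print(f"got      {num_angle_bin} x {actual_edge_dim} = {actual_width} per node")
```

If the assumed hyperparameters are right, this prints a width of 144 (8 × 18) against the expected 472 (8 × 59), which points at the input graph's edge features rather than at the loaded weights.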