def get_param_size(model):
    """Return the total number of scalar elements across a model's parameters.

    For each object yielded by ``model.parameters()``, the element count is
    the product of the dimensions in ``p.size()``, and these counts are
    summed.

    Presumably ``model`` is a ``torch.nn.Module``; confirm against callers.
    Any object whose ``parameters()`` yields items with a ``size()`` method
    returning an iterable of ints works.

    Args:
        model: Object exposing ``parameters()``.

    Returns:
        int: Total element count. Returns 0 when ``parameters()`` yields
        nothing. A 0-d parameter (empty shape) counts as 1 element, because
        the product of no dimensions is 1.
    """
    import math  # local import keeps this block self-contained

    return sum(math.prod(p.size()) for p in model.parameters())