```python
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    resolved_archive_file,
    cache_dir=cache_dir,
    force_download=force_download,
    proxies=proxies,
    resume_download=resume_download,
    local_files_only=local_files_only,
    token=token,
    user_agent=user_agent,
    revision=revision,
    subfolder=subfolder,
    _commit_hash=commit_hash,
)
```
After this call runs, you can print both of these variables to inspect them:
```
(Pdb) p resolved_archive_file
['/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model-00001-of-00004.safetensors',
 '/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model-00002-of-00004.safetensors',
 '/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model-00003-of-00004.safetensors',
 '/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model-00004-of-00004.safetensors']
```
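Both the shard list and `sharded_metadata` are derived from the checkpoint's index file, `model.safetensors.index.json`, whose `weight_map` records which shard stores each parameter. A minimal sketch of inspecting that index directly (the path mirrors the cache layout printed above and is illustrative):

```python
import json

# Illustrative path, following the cache layout shown in the pdb output.
index_path = "/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model.safetensors.index.json"

with open(index_path) as f:
    index = json.load(f)

print(index["metadata"]["total_size"])  # total checkpoint size in bytes
# weight_map: parameter name -> shard file, e.g.
# "model.embed_tokens.weight" -> "model-00001-of-00004.safetensors"
for name, shard in list(index["weight_map"].items())[:3]:
    print(name, "->", shard)
```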
```python
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
with ContextManagers(init_contexts):
    # Let's make sure we don't run the init function of buffer modules
    model = cls(config, *model_args, **model_kwargs)
```
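`ContextManagers` simply enters every context manager in the list, so the model's `__init__` runs inside `deepspeed.zero.Init` here. The key idea is to build the module tree without materializing real weights. A rough sketch of the same idea using PyTorch's meta device (an analogy, not the DeepSpeed code path; assumes PyTorch >= 2.0 and access to the model's config):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Under the meta device, module construction records shapes and dtypes but
# allocates no real storage -- the skeleton is filled in later, shard by shard.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta
```

The loop below then walks the resolved shard files and loads each one's tensors into this skeleton.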
```python
for shard_file in resolved_archive_file:
    # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
    if shard_file in disk_only_shard_files:
        continue

    state_dict = load_state_dict(shard_file, is_quantized=is_quantized)
```
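For a `.safetensors` shard, `load_state_dict` ultimately delegates to the safetensors loader, which returns a plain `{name: tensor}` dict holding only that shard's weights. A minimal sketch, reusing one of the paths from the pdb output above:

```python
from safetensors.torch import load_file

shard_path = "/home/dell/sdb/.cache/Meta-Llama-3-8B-Instruct/model-00001-of-00004.safetensors"

# Contains only the tensors stored in this shard, not the whole model.
state_dict = load_file(shard_path, device="cpu")
print(len(state_dict), "tensors in shard 1")
```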
```python
    # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
    # matching the weights in the model.
    mismatched_keys += _find_mismatched_keys(
        state_dict,
        model_state_dict,
        original_loaded_keys,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    )
    if low_cpu_mem_usage:
        if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
            for key, param in model_to_load.state_dict().items():
                if param.device == torch.device("meta"):
                    set_module_tensor_to_device(
                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
                    )
        else:
            new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                model_to_load,
                state_dict,
                loaded_keys,
                start_prefix,
                expected_keys,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_index=offload_index,
                state_dict_folder=state_dict_folder,
                state_dict_index=state_dict_index,
                dtype=dtype,
                hf_quantizer=hf_quantizer,
                is_safetensors=is_safetensors,
                keep_in_fp32_modules=keep_in_fp32_modules,
                unexpected_keys=unexpected_keys,
            )
            error_msgs += new_error_msgs
    else:
        # Sharded checkpoint or whole but low_cpu_mem_usage==True
        if assign_to_params_buffers is None:
            assign_to_params_buffers = check_support_param_buffer_assignment(
                model_to_load, state_dict, start_prefix
            )
        error_msgs += _load_state_dict_into_model(
            model_to_load, state_dict, start_prefix, assign_to_params_buffers
        )
```
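On the low-memory path, `_load_state_dict_into_meta_model` (via accelerate's `set_module_tensor_to_device`) materializes weights one at a time: walk the module tree by the dotted parameter name, then swap the meta placeholder for a real tensor. A simplified sketch of that mechanism (not the actual accelerate implementation):

```python
import torch

def set_tensor_by_name(model: torch.nn.Module, key: str, value: torch.Tensor) -> None:
    # Walk e.g. "model.layers.0.self_attn.q_proj.weight" down to the owning
    # module, then replace the meta placeholder with a real tensor.
    *module_path, leaf = key.split(".")
    module = model
    for name in module_path:
        module = getattr(module, name)
    old = getattr(module, leaf)
    assert old.shape == value.shape, f"shape mismatch for {key}"
    if isinstance(old, torch.nn.Parameter):
        module._parameters[leaf] = torch.nn.Parameter(value, requires_grad=old.requires_grad)
    else:  # non-parameter state, i.e. a registered buffer
        module._buffers[leaf] = value
```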
```python
    # force memory release
    del state_dict
    gc.collect()
```
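Freeing each shard's `state_dict` before touching the next one keeps peak CPU memory near one shard's worth of tensors instead of the full checkpoint. Putting the steps together, a self-contained sketch of the shard-by-shard pattern (illustrative; it assumes the model already has real CPU storage, so `strict=False` is used because each shard covers only part of the model, while the meta-device path would need the per-tensor setter shown above instead):

```python
import gc
import json
import os

import torch
from safetensors.torch import load_file

def load_sharded_checkpoint(model: torch.nn.Module, checkpoint_dir: str) -> None:
    """Load a sharded safetensors checkpoint one shard at a time (sketch)."""
    with open(os.path.join(checkpoint_dir, "model.safetensors.index.json")) as f:
        index = json.load(f)

    # Visit each shard file once, in a stable order.
    shard_files = sorted(set(index["weight_map"].values()))
    for shard_file in shard_files:
        state_dict = load_file(os.path.join(checkpoint_dir, shard_file), device="cpu")
        # strict=False: each shard holds only a subset of the model's weights.
        model.load_state_dict(state_dict, strict=False)
        # Release this shard's tensors before loading the next shard.
        del state_dict
        gc.collect()
```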