BaseModel
- class models.BaseModel.BaseModel(*args: Any, **kwargs: Any)
Base class for all models.
- __init__(config: utils.util.Config)
- Parameters
config – the configuration object initialized by utils.manager.Manager.setup()
name – the name of the model
- config
- index_dir
the folder to save indexes, e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex
- Type
- collection_dir
the folder to save JSON collections returned by e.g. utils.index.AnseriniBM25Index.fit()
- Type
- encode_dir
the folder to save the text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()
- Type
- query_dir
the folder to save the query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()
- Type
- retrieval_result_path
the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()
- Type
- logger
the logger
- Type
MasterLoger
- _move_to_device(data, exclude_keys=['text_idx', 'query_idx'])
Move data to device.
- Parameters
exclude_keys – variables that should be kept unchanged
- _l2_distance(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor
Compute l2 similarity.
- Parameters
x1 – tensor of [B, D]
x2 – tensor of [B, D]
- _cos_sim(x1: torch.Tensor, x2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor
Compute cosine similarity.
- Parameters
x1 – tensor of [B, D]
x2 – tensor of [B, D]
temperature – scale the similarity scores by dividing temperature
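The following is a minimal sketch of temperature-scaled cosine similarity, not the library's implementation. It uses NumPy instead of torch so that it runs standalone, and it assumes the two [B, D] inputs are compared row by row; whether the actual method scores rows pairwise or across the whole batch is not stated here. The name cos_sim is hypothetical.

```python
import numpy as np

def cos_sim(x1: np.ndarray, x2: np.ndarray, temperature: float = 0.1) -> np.ndarray:
    """Row-wise cosine similarity of two [B, D] arrays, divided by temperature."""
    # Normalize each row to unit length; eps guards against division by zero.
    eps = 1e-12
    x1_unit = x1 / np.maximum(np.linalg.norm(x1, axis=-1, keepdims=True), eps)
    x2_unit = x2 / np.maximum(np.linalg.norm(x2, axis=-1, keepdims=True), eps)
    # The dot product of unit vectors is the cosine; a smaller temperature
    # spreads the scores further apart.
    return (x1_unit * x2_unit).sum(axis=-1) / temperature

scores = cos_sim(
    np.array([[1.0, 0.0], [1.0, 1.0]]),
    np.array([[2.0, 0.0], [-1.0, -1.0]]),
)
```

With the default temperature of 0.1, a cosine of 1 becomes a score of 10, which makes a subsequent softmax much sharper than it would be on raw cosines.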
- _compute_teacher_score(x)
Compute teacher score in knowledge distillation; return None if training in contrastive mode.
- _compute_loss(score: torch.Tensor, label: torch.Tensor, teacher_score: Optional[torch.Tensor] = None)
A general method to compute loss (contrastive cross-entropy or distillation)
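As a rough illustration of the two modes named above, the sketch below computes a contrastive cross-entropy when no teacher score is given and a KL-divergence distillation loss otherwise. This is an assumption about what the two modes mean, written in NumPy for portability; the function names, the [B, N] score layout, and the choice of KL divergence are all hypothetical and may differ from the actual method.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the row maximum for numerical stability.
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def compute_loss(score, label, teacher_score=None) -> float:
    """score: [B, N] candidate scores; label: [B] index of the positive candidate."""
    log_prob = log_softmax(score)
    if teacher_score is None:
        # Contrastive cross-entropy: negative log-probability of the positive.
        return float(-log_prob[np.arange(len(label)), label].mean())
    # Distillation: KL(teacher || student), averaged over the batch.
    teacher_log_prob = log_softmax(teacher_score)
    kl = (np.exp(teacher_log_prob) * (teacher_log_prob - log_prob)).sum(axis=-1)
    return float(kl.mean())

score = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
label = np.array([0, 2])
contrastive = compute_loss(score, label)
# A student that matches the teacher exactly incurs no distillation loss.
distilled = compute_loss(score, label, teacher_score=score)
```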
- create_optimizer() -> torch.optim.Optimizer
Create the optimizer. Subclasses may override this function to create custom optimizers. Return None to use the default optimizer created by Trainer.
- Returns
optimizer
- _gather_objects(local_object: object) -> list[object]
Gather common python objects across processes.
Note
This function implicitly consumes GPU memory.
- Parameters
local_object – python object to collect
- _gather_tensors(local_tensor: torch.Tensor) -> torch.Tensor
Gather tensors from all gpus on each process.
- Parameters
local_tensor – the tensor that needs to be gathered
- Returns
concatenation of local_tensor in each process
- save_to_mmp(path: str, shape: tuple, dtype: numpy.dtype, loader: torch.utils.data.DataLoader, obj: numpy.ndarray, batch_size: int = 1000)
Create a np.memmap file of shape with dtype;
Create lock;
Save the obj to the offset utils.util.Sequential_Sampler.start;
Release lock.
- Parameters
path – the memmap file path
shape – the shape of the memmap file to be created
dtype – the data type of the memmap file to be created
loader – the dataloader for the data
obj – the array to be stored
batch_size – saving in batch
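The write-at-offset pattern described above can be sketched as follows. This is a single-process illustration under stated assumptions: the helper save_shard and its start argument are hypothetical stand-ins for the offset taken from the sampler, and the lock is omitted because nothing here writes concurrently.

```python
import os
import tempfile
import numpy as np

def save_shard(path, shape, dtype, obj, start, batch_size=1000):
    """Write obj into rows [start, start + len(obj)) of a memmap file at path."""
    # Create the file on first use; later writers reopen it without truncating.
    mode = "r+" if os.path.exists(path) else "w+"
    mmp = np.memmap(path, dtype=dtype, mode=mode, shape=shape)
    # Copy in batches so that only batch_size rows are written at a time.
    for i in range(0, len(obj), batch_size):
        j = min(i + batch_size, len(obj))
        mmp[start + i : start + j] = obj[i:j]
    mmp.flush()

path = os.path.join(tempfile.mkdtemp(), "embeddings.mmp")
# Two shards, as two processes would write them, each at its own offset.
save_shard(path, (6, 2), np.float32, np.ones((3, 2), dtype=np.float32), start=0)
save_shard(path, (6, 2), np.float32, np.full((3, 2), 2, dtype=np.float32), start=3, batch_size=2)
stored = np.memmap(path, dtype=np.float32, mode="r", shape=(6, 2))
```

Because each shard owns a disjoint row range, the writes never overlap; a lock is still useful around file creation so that only one process creates the file.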
- gather_retrieval_result(retrieval_result: Union[dict[int, list[int]], dict[int, list[tuple[int, float]]]], hits: Optional[int] = None, retrieval_result_path: Optional[str] = None) -> Union[dict[int, list[int]], dict[int, list[tuple[int, float]]]]
Gather retrieval_result across processes;
Return the reordered result cut off to top k;
Create a lock;
Save the result at models.BaseModel.BaseModel.retrieval_result_path;
Release the lock.
- Parameters
retrieval_result – each tuple is a document id-score pair
- Returns
processed retrieval result
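The merge-and-cut-off step can be sketched in plain Python. This assumes that each process returns a dict of (document id, score) pairs and that a higher score is better; the helper merge_results is hypothetical, and the cross-process gathering, the lock, and the file write are left out.

```python
from typing import Optional

def merge_results(per_process, hits: Optional[int] = None):
    """Merge per-process {query_id: [(doc_id, score), ...]} dicts and keep the top hits."""
    merged = {}
    for result in per_process:
        for qid, pairs in result.items():
            merged.setdefault(qid, []).extend(pairs)
    for qid, pairs in merged.items():
        # Highest score first; hits=None keeps every pair.
        pairs.sort(key=lambda pair: pair[1], reverse=True)
        merged[qid] = pairs[:hits]
    return merged

rank0 = {1: [(10, 0.9), (11, 0.2)]}
rank1 = {1: [(12, 0.5)], 2: [(13, 0.1)]}
merged = merge_results([rank0, rank1], hits=2)
```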
- init_verifier(loaders: dict[str, torch.utils.data.DataLoader], load_all_verifier: bool = False)
Initialize the post verifier defined in utils.index.VERIFIER_MAP.
- Parameters
loaders –
load_all_verifier – if True, load all the verifier embeddings/codes
- encode(loaders)
Shortcut for encoding both text and query.
- index(loaders: dict[str, torch.utils.data.DataLoader])
The index method. Subclass should override this function.
- retrieve(loaders: dict[str, torch.utils.data.DataLoader])
The retrieve method. Subclass should override this function.
- log_result(**kwargs)
Save the model metrics and configurations in performance.log.
- load()
Load the model with current config from config.load_ckpt.
- step_end_callback(loaders, state)
Callback at the end of each training step.
BaseSparseModel
- class models.BaseModel.BaseSparseModel(*args: Any, **kwargs: Any)
Base class for all models that rely on token weights to rank documents.
- __init__(config: utils.util.Config)
- Parameters
config – the configuration object initialized by utils.manager.Manager.setup()
name – the name of the model
- config
- index_dir
the folder to save indexes, e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex
- Type
- collection_dir
the folder to save JSON collections returned by e.g. utils.index.AnseriniBM25Index.fit()
- Type
- encode_dir
the folder to save the text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()
- Type
- query_dir
the folder to save the query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()
- Type
- retrieval_result_path
the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()
- Type
- logger
the logger
- Type
MasterLoger
- _compute_overlap(query_token_id: torch.Tensor, text_token_id: torch.Tensor) -> torch.Tensor
Compute overlapping mask between the query tokens and positive sequence tokens across batches.
- Parameters
query_token_id – [B1, LQ]
text_token_id – [B2, LS]
- Returns
overlapping_mask – [B, LQ, B, LS] if cross_batch, else [B, LQ, LS]
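One way to build such a mask is with broadcasting, as in the sketch below. It is an illustration in NumPy rather than the method's actual code: the function name and the cross_batch argument are assumptions, and padding or special tokens, which a real implementation would likely mask out, are ignored.

```python
import numpy as np

def compute_overlap(query_token_id, text_token_id, cross_batch=True):
    """Boolean mask that is True where a query token id equals a text token id."""
    if cross_batch:
        # [B1, LQ, 1, 1] == [1, 1, B2, LS] broadcasts to [B1, LQ, B2, LS].
        return query_token_id[:, :, None, None] == text_token_id[None, None, :, :]
    # [B, LQ, 1] == [B, 1, LS] broadcasts to [B, LQ, LS].
    return query_token_id[:, :, None] == text_token_id[:, None, :]

query = np.array([[5, 7], [9, 9]])        # [B=2, LQ=2]
text = np.array([[7, 1, 5], [9, 2, 3]])   # [B=2, LS=3]
cross = compute_overlap(query, text)
paired = compute_overlap(query, text, cross_batch=False)
```

In the cross-batch case every query is compared with every text in the batch, which is what in-batch negatives require; the paired case compares each query only with its own text.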
- _gate_text(text_token_weights: numpy.ndarray, k: Optional[int] = None)
Gate the text token weights so that only the top config.query_gate_k tokens are valid. Keep the text_token_ids because they are used to construct the entire inverted lists.
- Parameters
text_token_weights – [N, L, 1]
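Top-k gating of this kind can be sketched as below: every weight outside the k largest in a sequence is set to zero while the array keeps its shape, so the token ids stay aligned with the weights. The helper gate_top_k is hypothetical and uses NumPy only; how the actual method treats ties or padding is not covered.

```python
import numpy as np

def gate_top_k(token_weights: np.ndarray, k: int) -> np.ndarray:
    """Zero every weight except the k largest in each row of an [N, L, 1] array."""
    weights = token_weights.squeeze(-1)                      # [N, L]
    # argpartition places the indices of the k largest values in the last k columns.
    top_idx = np.argpartition(weights, -k, axis=-1)[:, -k:]  # [N, k]
    gated = np.zeros_like(weights)
    top_values = np.take_along_axis(weights, top_idx, axis=-1)
    np.put_along_axis(gated, top_idx, top_values, axis=-1)
    return gated[..., None]                                  # back to [N, L, 1]

weights = np.array([[[0.1], [0.9], [0.5]], [[0.7], [0.2], [0.3]]])
gated = gate_top_k(weights, k=2)
```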
- encode_text_step(x)
One step in encode_text.
- Parameters
x – a data record.
- Returns
the text token id for indexing, array of [B, L]
the text token embedding for indexing, array of [B, L, D]
- encode_query_step(x)
One step in encode_query.
- Parameters
x – a data record.
- Returns
the query token id for searching, array of [B, L]
the query token embedding for searching, array of [B, L, D]
- inverted_index(loader_text: torch.utils.data.DataLoader)
Construct utils.index.BaseInvertedIndex.
- anserini_index(loader_text: torch.utils.data.DataLoader)
Construct utils.index.BaseAnseriniIndex.
- index(**kwargs)
The index method. Subclass should override this function.
- retrieve(**kwargs)
The retrieve method. Subclass should override this function.
BaseDenseModel
- class models.BaseModel.BaseDenseModel(*args: Any, **kwargs: Any)
Base class for all models that rely on sequence embeddings to rank documents.
- __init__(config)
- Parameters
config – the configuration object initialized by utils.manager.Manager.setup()
name – the name of the model
- config
- index_dir
the folder to save indexes, e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex
- Type
- collection_dir
the folder to save JSON collections returned by e.g. utils.index.AnseriniBM25Index.fit()
- Type
- encode_dir
the folder to save the text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()
- Type
- query_dir
the folder to save the query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()
- Type
- retrieval_result_path
the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()
- Type
- logger
the logger
- Type
MasterLoger
- faiss_index(loader_text: torch.utils.data.DataLoader)
Construct utils.index.FaissIndex.
- index(**kwargs)
The index method. Subclass should override this function.
- retrieve(**kwargs)
The retrieve method. Subclass should override this function.
BaseGenerativeModel
- class models.BaseModel.BaseGenerativeModel(*args: Any, **kwargs: Any)
Base class for generative models e.g. DSI, WebUltron.
- __init__(config: utils.util.Config)
- Parameters
config – the configuration object initialized by utils.manager.Manager.setup()
name – the name of the model
- config
- index_dir
the folder to save indexes, e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex
- Type
- collection_dir
the folder to save JSON collections returned by e.g. utils.index.AnseriniBM25Index.fit()
- Type
- encode_dir
the folder to save the text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()
- Type
- query_dir
the folder to save the query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()
- Type
- retrieval_result_path
the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()
- Type
- logger
the logger
- Type
MasterLoger
- generative_index(loader_text: torch.utils.data.DataLoader)
Construct utils.index.TrieIndex.
- index(**kwargs)
The index method. Subclass should override this function.
- retrieve(**kwargs)
The retrieve method. Subclass should override this function.