Deprecate reserved_label_len arg
This commit is contained in:
hiyouga
2024-07-01 01:19:27 +08:00
parent d4e2af1fa4
commit 1771251ce3
13 changed files with 329 additions and 223 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@unique
class Role(str, Enum):
USER = "user"
@@ -39,13 +42,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",