# sys_sampling.py — 772 bytes
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

# Systematic train/valid/test labelling.
#
# For each category, the table under categorized_parquet/category=<name> is
# read, every row gets a 'label' value derived from its position within the
# current run of identical 'topic' values, and the labelled table is written
# to the 'dataset' directory partitioned by topic and label.
categories = ['NWRW19', 'NPRW19', 'NLRW19', 'NIRW19']

# Position inside each 100-row cycle -> held-out split name.
# Every position not listed here is labelled 'train'.
_HELD_OUT = {49: 'valid', 99: 'test'}

for category in categories:
    table = pq.read_table(f'categorized_parquet/category={category}')

    labels = []
    position = 0   # rows seen so far in the current run of equal topics
    previous = ''  # topic value of the preceding row ('' before the first row)
    # NOTE(review): the counter restarts on every change of topic value, so
    # this presumably expects rows of one topic to be contiguous -- confirm.
    for current in tqdm(table['topic']):
        if current != previous:
            position, previous = 0, current
        labels.append(_HELD_OUT.get(position % 100, 'train'))
        position += 1

    labelled = table.append_column('label', pa.array(labels))
    pq.write_to_dataset(
        labelled,
        root_path='dataset',
        partition_cols=['topic', 'label'],
        coerce_timestamps='us',
    )