# sys_sampling.py
# 772 Bytes
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
# Values substituted into the 'categorized_parquet/category=<value>' read path.
categories = ['NWRW19', 'NPRW19', 'NLRW19', 'NIRW19']


def systematic_labels(topics, period=100, valid_slot=49, test_slot=99):
    """Label each row 'train', 'valid' or 'test' by systematic sampling per topic.

    Rows are counted separately for every distinct topic value, starting at 0.
    Within each run of ``period`` rows of a topic, the row at offset
    ``valid_slot`` is labelled 'valid', the row at offset ``test_slot`` is
    labelled 'test', and every other row is labelled 'train'.

    The counter is kept per topic (not reset whenever the topic changes from
    the previous row), so the result does not depend on rows of the same topic
    being contiguous. For input already grouped by topic the labels are the
    same as with a reset-on-change counter.

    Args:
        topics: Iterable of hashable topic values, one per row.
        period: Length of the sampling cycle.
        valid_slot: Offset inside the cycle that becomes 'valid'.
        test_slot: Offset inside the cycle that becomes 'test'.

    Returns:
        A list of label strings, one per input row, in input order.
    """
    rows_seen = {}
    labels = []
    for topic in topics:
        position = rows_seen.get(topic, 0)
        rows_seen[topic] = position + 1
        slot = position % period
        if slot == valid_slot:
            labels.append('valid')
        elif slot == test_slot:
            labels.append('test')
        else:
            labels.append('train')
    return labels


def main():
    """Read each category's parquet data, add a 'label' column, and write it out.

    Output goes to the 'dataset' directory, partitioned by 'topic' and 'label'.
    """
    for category in categories:
        table = pq.read_table(f'categorized_parquet/category={category}')
        # Convert once to plain Python values so comparisons and dict lookups
        # operate on ordinary objects instead of per-element Arrow scalars.
        topics = table['topic'].to_pylist()
        labels = systematic_labels(tqdm(topics))
        pq.write_to_dataset(
            # Explicit string type keeps the column well-typed even when the
            # table has no rows (an empty list would otherwise infer null).
            table.append_column('label', pa.array(labels, type=pa.string())),
            root_path='dataset',
            partition_cols=['topic', 'label'],
            coerce_timestamps='us',
        )


if __name__ == '__main__':
    main()