def prepare_dataset(data_path, tokenizer, slot_meta,
max_seq_length, n_history=1, diag_level=False, op_code='4'):
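# Builds one TrainingInstance per (system utterance, user utterance) pair,
# carrying the cumulative dialogue state forward turn by turn.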
dials = json.load(open(data_path))
data = []
domain_counter = {}
max_resp_len, max_value_len = 0, 0
max_line = None
c = 0
for i, dial_dict in enumerate(dials):
if (i+1) % 200 == 0:
print("prepare {:}/{:}".format(i+1, len(dials)))
sys.stdout.flush()
for domain in dial_dict["domains"]:
if domain not in EXPERIMENT_DOMAINS:
continue
if domain not in domain_counter.keys():
domain_counter[domain] = 0
domain_counter[domain] += 1
dialog_history = []
last_dialog_state = {}
last_uttr = ""
prev_turn_state = []
ti = 0
turn_id = 0
while ti < len(dial_dict["dialogue"]):
turn = dial_dict["dialogue"][ti]
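# Pair the current utterance with the following one so each training turn is a
# (system, user) utterance pair; a leading user turn and a trailing system turn
# are handled as special cases.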
if ti == 0 and turn["role"] == "user":
turn2 = turn
turn_uttr = ("" + ' ; ' + turn2["text"]).strip()
elif (ti + 1) == len(dial_dict["dialogue"]):
turn2 = {"state": []}
turn_uttr = (turn["text"] + ' ; ' + "").strip()
else:
ti += 1
turn2 = dial_dict["dialogue"][ti]
turn_uttr = (turn["text"] + ' ; ' + turn2["text"]).strip()
dialog_history.append(last_uttr)
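# Each state entry is a "domain-slot-value" string; convert it into a
# {"domain-slot": value} dict. If this turn carries no state, fall back to the
# most recent non-empty one, otherwise skip the turn.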
if turn2["state"]:
turn_dialog_state = [d_s_v.split("-") for d_s_v in turn2["state"]]
turn_dialog_state = {d_s_v[0] + "-" + d_s_v[1]: d_s_v[2] for d_s_v in turn_dialog_state}
turn_domain = turn2["state"][-1].split("-")[0]
prev_turn_state = turn2["state"]
elif prev_turn_state:
turn_dialog_state = [d_s_v.split("-") for d_s_v in prev_turn_state]
turn_dialog_state = {d_s_v[0] + "-" + d_s_v[1]: d_s_v[2] for d_s_v in turn_dialog_state}
turn_domain = prev_turn_state[-1].split("-")[0]
else:
ti += 1
continue
last_uttr = turn_uttr
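# Diff the previous dialogue state against the current one to obtain per-slot
# operation labels and the value tokens the decoder must generate.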
op_labels, generate_y, gold_state = make_turn_label(slot_meta, last_dialog_state,
turn_dialog_state,
tokenizer, op_code)
if (ti + 1) == len(dial_dict["dialogue"]):
is_last_turn = True
else:
is_last_turn = False
# turn_uttr = tokenizer.tokenize(turn_uttr)
# dial_his = tokenizer.tokenize(' '.join(dialog_history[-n_history:]))
instance = TrainingInstance(dial_dict["dialogue_idx"], turn_domain,
turn_id, turn_uttr, ' '.join(dialog_history[-n_history:]),
last_dialog_state, op_labels,
generate_y, gold_state, max_seq_length, slot_meta,
is_last_turn, op_code=op_code)
instance.make_instance(tokenizer)
data.append(instance)
c += 1
turn_id += 1
ti += 1
last_dialog_state = turn_dialog_state
return data
class WOSDataset(Dataset):
def __init__(self, data, tokenizer, slot_meta, max_seq_length, rng,
ontology, word_dropout=0.1, shuffle_state=False, shuffle_p=0.5):
self.data = data
self.len = len(data)
self.tokenizer = tokenizer
self.slot_meta = slot_meta
self.max_seq_length = max_seq_length
self.ontology = ontology
self.word_dropout = word_dropout
self.shuffle_state = shuffle_state
self.shuffle_p = shuffle_p
self.rng = rng
def __len__(self):
return self.len
def __getitem__(self, idx):
if self.shuffle_state and self.shuffle_p > 0.:
if self.rng.random() < self.shuffle_p:
self.data[idx].shuffle_state(self.rng, None)
else:
self.data[idx].shuffle_state(self.rng, self.slot_meta)
if self.word_dropout > 0 or self.shuffle_state:
self.data[idx].make_instance(self.tokenizer,
word_dropout=self.word_dropout)
return self.data[idx]
def collate_fn(self, batch):
input_ids = torch.tensor([f.input_id for f in batch], dtype=torch.long)
input_mask = torch.tensor([f.input_mask for f in batch], dtype=torch.long)
segment_ids = torch.tensor([f.segment_id for f in batch], dtype=torch.long)
state_position_ids = torch.tensor([f.slot_position for f in batch], dtype=torch.long)
op_ids = torch.tensor([f.op_ids for f in batch], dtype=torch.long)
domain_ids = torch.tensor([f.domain_id for f in batch], dtype=torch.long)
gen_ids = [b.generate_ids for b in batch]
try:
max_update = max([len(b) for b in gen_ids])
except ValueError:  # max() of an empty sequence
print(f"max_update error: {len(gen_ids)}")
max_update = 0
try:
max_value = max([len(b) for b in flatten(gen_ids)])
except ValueError:  # no slot values to generate anywhere in the batch
print(f"max_value error: {len(gen_ids)}")
max_value = 0
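# Pad every value to max_value tokens and every example to max_update generated
# slots so gen_ids can be stacked into a single (batch, max_update, max_value) tensor.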
for bid, b in enumerate(gen_ids):
n_update = len(b)
for idx, v in enumerate(b):
b[idx] = v + [0] * (max_value - len(v))
gen_ids[bid] = b + [[0] * max_value] * (max_update - n_update)
gen_ids = torch.tensor(gen_ids, dtype=torch.long)
return input_ids, input_mask, segment_ids, state_position_ids, op_ids, domain_ids, gen_ids, max_value, max_update
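As a rough usage sketch (the file names, tokenizer choice, and hyperparameters below are placeholders rather than values taken from this code), the training pieces above would typically be wired together like this:
# Hypothetical wiring; adjust names and paths to the actual project layout.
import json
import random
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
slot_meta = json.load(open("slot_meta.json"))   # list of "domain-slot" strings
ontology = json.load(open("ontology.json"))     # slot -> candidate values
rng = random.Random(42)

train_data = prepare_dataset("train_dials.json", tokenizer, slot_meta,
                             max_seq_length=256, op_code='4')
train_set = WOSDataset(train_data, tokenizer, slot_meta, 256, rng,
                       ontology, word_dropout=0.1)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True,
                          collate_fn=train_set.collate_fn)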
Data preparation for inference
def prepare_eval_dataset(data_path, tokenizer, slot_meta,
max_seq_length, n_history=1, diag_level=False, op_code='4'):
dials = json.load(open(data_path))
data = []
domain_counter = {}
max_resp_len, max_value_len = 0, 0
max_line = None
c = 0
for i, dial_dict in enumerate(dials):
if (i+1) % 200 == 0:
print("prepare {:}/{:}".format(i+1, len(dials)))
sys.stdout.flush()
dom = []
for domain in dial_dict["domains"]:
if domain not in EXPERIMENT_DOMAINS:
continue
if domain not in domain_counter.keys():
domain_counter[domain] = 0
domain_counter[domain] += 1
dom.append(domain)
dialog_history = []
last_dialog_state = {}
last_uttr = ""
prev_turn_state = []
ti = 0
turn_id = 0
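# Gold states are not available at inference time, so add_state seeds every turn
# with "domain-slot-none" placeholders for the dialogue's active domains.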
dial_dict = add_state(dial_dict, dom, slot_meta)
while ti < len(dial_dict["dialogue"]):
turn = dial_dict["dialogue"][ti]
if ti == 0 and turn["role"] == "user":
turn2 = turn
turn_uttr = ("" + ' ; ' + turn2["text"]).strip()
elif (ti + 1) == len(dial_dict["dialogue"]):
turn2 = {"state": []}
turn_uttr = (turn["text"] + ' ; ' + "").strip()
else:
ti += 1
turn2 = dial_dict["dialogue"][ti]
turn_uttr = (turn["text"] + ' ; ' + turn2["text"]).strip()
dialog_history.append(last_uttr)
if turn2["state"]:
turn_dialog_state = [d_s_v.split("-") for d_s_v in turn2["state"]]
turn_dialog_state = {d_s_v[0] + "-" + d_s_v[1]: d_s_v[2] for d_s_v in turn_dialog_state}
turn_domain = turn2["state"][-1].split("-")[0]
prev_turn_state = turn2["state"]
elif prev_turn_state:
turn_dialog_state = [d_s_v.split("-") for d_s_v in prev_turn_state]
turn_dialog_state = {d_s_v[0] + "-" + d_s_v[1]: d_s_v[2] for d_s_v in turn_dialog_state}
turn_domain = prev_turn_state[-1].split("-")[0]
else:
ti += 1
continue
last_uttr = turn_uttr
op_labels, generate_y, gold_state = make_turn_label(slot_meta, last_dialog_state,
turn_dialog_state,
tokenizer, op_code)
if (ti + 1) == len(dial_dict["dialogue"]):
is_last_turn = True
else:
is_last_turn = False
# turn_uttr = tokenizer.tokenize(turn_uttr)
# dial_his = tokenizer.tokenize(' '.join(dialog_history[-n_history:]))
instance = TrainingInstance(dial_dict["dialogue_idx"], turn_domain,
turn_id, turn_uttr, ' '.join(dialog_history[-n_history:]),
last_dialog_state, op_labels,
generate_y, gold_state, max_seq_length, slot_meta,
is_last_turn, op_code=op_code)
instance.make_instance(tokenizer)
data.append(instance)
c += 1
turn_id += 1
ti += 1
last_dialog_state = turn_dialog_state
return data
def add_state(dial_dict, dom, slot_meta):
tmp = []
for slot in slot_meta:
for x in dom:
if x in slot:
tmp.append(slot + "-none")
for dial in dial_dict["dialogue"]:
dial["state"] = tmp
return dial_dict
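# Example (hypothetical values): with slot_meta = ["restaurant-food", "hotel-area"]
# and dom = ["restaurant"], every turn's "state" becomes ["restaurant-food-none"].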
def inference(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
is_gt_op=False, is_gt_p_state=False, is_gt_gen=False):
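# Decodes the test dialogues turn by turn, carrying the predicted dialogue state
# forward unless ground-truth inputs are requested via the is_gt_* flags.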
model.eval()
op2id = OP_SET[op_code]
id2op = {v: k for k, v in op2id.items()}
id2domain = {v: k for k, v in domain2id.items()}
slot_turn_acc, joint_acc, slot_F1_pred, slot_F1_count = 0, 0, 0, 0
final_joint_acc, final_count, final_slot_F1_pred, final_slot_F1_count = 0, 0, 0, 0
op_acc, op_F1, op_F1_count = 0, {k: 0 for k in op2id}, {k: 0 for k in op2id}
all_op_F1_count = {k: 0 for k in op2id}
tp_dic = {k: 0 for k in op2id}
fn_dic = {k: 0 for k in op2id}
fp_dic = {k: 0 for k in op2id}
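# Metric accumulators are initialized but not updated below; this function only
# collects per-turn predictions in `results`.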
results = {}
last_dialog_state = {}
wall_times = []
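# Reset the running state at the start of every new dialogue (turn_id == 0).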
for di, i in enumerate(tqdm(test_data)):
if i.turn_id == 0:
last_dialog_state = {}
if is_gt_p_state is False:
i.last_dialog_state = deepcopy(last_dialog_state)
i.make_instance(tokenizer, word_dropout=0.)
else: # ground-truth previous dialogue state
last_dialog_state = deepcopy(i.gold_p_state)
i.last_dialog_state = deepcopy(last_dialog_state)
i.make_instance(tokenizer, word_dropout=0.)
input_ids = torch.LongTensor([i.input_id]).to(device)
input_mask = torch.FloatTensor([i.input_mask]).to(device)
segment_ids = torch.LongTensor([i.segment_id]).to(device)
state_position_ids = torch.LongTensor([i.slot_position]).to(device)
d_gold_op, _, _ = make_turn_label(slot_meta, last_dialog_state, i.gold_state,
tokenizer, op_code, dynamic=True)
gold_op_ids = torch.LongTensor([d_gold_op]).to(device)
start = time.perf_counter()
MAX_LENGTH = 9
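# Cap on the number of tokens decoded for any generated slot value.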
with torch.no_grad():
# ground-truth state operation
gold_op_inputs = gold_op_ids if is_gt_op else None
d, s, g = model(input_ids=input_ids,
token_type_ids=segment_ids,
state_positions=state_position_ids,
attention_mask=input_mask,
max_value=MAX_LENGTH,
op_ids=gold_op_inputs)
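# d: domain scores (unused here), s: per-slot operation scores,
# g: logits over the tokens of generated slot values.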
_, op_ids = s.view(-1, len(op2id)).max(-1)
if g.size(1) > 0:
generated = g.squeeze(0).max(-1)[1].tolist()
else:
generated = []
if is_gt_op:
pred_ops = [id2op[a] for a in gold_op_ids[0].tolist()]
else:
pred_ops = [id2op[a] for a in op_ids.tolist()]
gold_ops = [id2op[a] for a in d_gold_op]
if is_gt_gen:
# ground_truth generation
gold_gen = {'-'.join(ii.split('-')[:2]): ii.split('-')[-1] for ii in i.gold_state}
else:
gold_gen = {}
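# Merge the predicted operations and generated tokens into the running dialogue
# state (optionally substituting gold values when is_gt_gen is set).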
generated, last_dialog_state = postprocessing(slot_meta, pred_ops, last_dialog_state,
generated, tokenizer, op_code, gold_gen)
end = time.perf_counter()
wall_times.append(end - start)
pred_state = []
for k, v in last_dialog_state.items():
pred_state.append('-'.join([k, v]))
key = str(i.id) + '-' + str(i.turn_id)
results[key] = [pred_state, i.gold_state]
return results
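A minimal scoring sketch (not part of the function above; the helper name is hypothetical) showing how the returned results dict can be turned into joint goal accuracy:
# Hypothetical post-hoc scoring of the results returned by inference().
def joint_goal_accuracy(results):
    matched = sum(set(pred) == set(gold) for pred, gold in results.values())
    return matched / max(len(results), 1)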