Commit c9f71cb2 by Gencer

fix

parent 6b8c39de
{
"original_labels":{
"full_labels":{
"Continuous urban fabric": 0,
"Discontinuous urban fabric": 1,
"Industrial or commercial units": 2,
......@@ -44,7 +44,7 @@
"Estuaries": 41,
"Sea and ocean": 42
},
"label_update": [
"label_conversion": [
[0, 1],
[2],
[11, 12, 13],
......@@ -65,7 +65,7 @@
[38, 39],
[40, 41, 42]
],
"updated_labels":{
"compact_labels":{
"Urban fabric": 0,
"Industrial or commercial units": 1,
"Arable land": 2,
......
......@@ -31,11 +31,11 @@ UPDATE_JSON = False
with open('label_indices.json', 'rb') as f:
label_indices = json.load(f)
label_update = label_indices['label_update']
updated_label_idx = {v: k for k, v in label_indices['updated_labels'].iteritems()}
label_conversion = label_indices['label_conversion']
compact_label_idx = {v: k for k, v in label_indices['compact_labels'].iteritems()}
def prep_example(bands, original_labels, updated_labels, original_multi_hot,
updated_multi_hot,
def prep_example(bands, full_labels, compact_labels, full_labels_multi_hot,
compact_labels_multi_hot,
patch_name):
return tf.train.Example(
features=tf.train.Features(
......@@ -64,16 +64,16 @@ def prep_example(bands, original_labels, updated_labels, original_multi_hot,
int64_list=tf.train.Int64List(value=np.ravel(bands['B11']))),
'B12': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B12']))),
'original_labels': tf.train.Feature(
'full_labels': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[i.encode('utf-8') for i in original_labels])),
'updated_labels': tf.train.Feature(
value=[i.encode('utf-8') for i in full_labels])),
'compact_labels': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[
i.encode('utf-8') for i in updated_labels])),
'original_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=original_multi_hot)),
'updated_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=updated_multi_hot)),
i.encode('utf-8') for i in compact_labels])),
'full_labels_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=full_labels_multi_hot)),
'compact_labels_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=compact_labels_multi_hot)),
'patch_name': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[patch_name.encode('utf-8')]))
}))
......@@ -97,39 +97,39 @@ def create_split(root_folder, patch_names, TFRecord_writer):
band_data = np.array(band_ds.read(1))
bands[band_name] = np.array(band_data)
original_multi_hot = np.zeros(
len(label_indices['original_labels'].keys()), dtype=int)
updated_multi_hot = np.zeros(len(label_update),dtype=int)
full_labels_multi_hot = np.zeros(
len(label_indices['full_labels'].keys()), dtype=int)
compact_labels_multi_hot = np.zeros(len(label_conversion),dtype=int)
patch_json_path = os.path.join(
patch_folder_path, patch_name + '_labels_metadata.json')
with open(patch_json_path, 'rb') as f:
patch_json = json.load(f)
original_labels = patch_json['labels']
for label in original_labels:
original_multi_hot[label_indices['original_labels'][label]] = 1
full_labels = patch_json['labels']
for label in full_labels:
full_labels_multi_hot[label_indices['full_labels'][label]] = 1
for i in range(len(label_update)):
updated_multi_hot[i] = (
np.sum(original_multi_hot[label_update[i]]) > 0
for i in range(len(label_conversion)):
compact_labels_multi_hot[i] = (
np.sum(full_labels_multi_hot[label_conversion[i]]) > 0
).astype(int)
updated_labels = []
for i in np.where(updated_multi_hot == 1)[0]:
updated_labels.append(updated_label_idx[i])
compact_labels = []
for i in np.where(compact_labels_multi_hot == 1)[0]:
compact_labels.append(compact_label_idx[i])
if UPDATE_JSON:
patch_json['updated_labels'] = updated_labels
patch_json['compact_labels'] = compact_labels
with open(patch_json_path, 'wb') as f:
json.dump(patch_json, f)
example = prep_example(
bands,
original_labels,
updated_labels,
original_multi_hot,
updated_multi_hot,
full_labels,
compact_labels,
full_labels_multi_hot,
compact_labels_multi_hot,
patch_name
)
TFRecord_writer.write(example.SerializeToString())
......@@ -144,7 +144,7 @@ if __name__ == "__main__":
parser.add_argument('-o', '--out_folder', dest = 'out_folder',
help = 'folder path containing resulting TFRecord files')
parser.add_argument('--update_json', default = False, action = "store_true", help =
'flag for adding updated label to the json file of each patch')
'flag for adding compact label to the json file of each patch')
parser.add_argument('-n', '--patch_names', dest = 'patch_names', help =
'csv files each of which contain list of patch names', nargs = '+')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment